diff --git a/atmat/lattice/atdivelem.m b/atmat/lattice/atdivelem.m index 1ab8899d7..b036fe96f 100644 --- a/atmat/lattice/atdivelem.m +++ b/atmat/lattice/atdivelem.m @@ -37,8 +37,8 @@ line=atsetfieldvalues(line,'ExitAngle',0.0); end if isfield(elem,'KickAngle') - line=atsetfieldvalues(line,'KickAngle',{1,1},el.KickAngle(1,1)*frac(:)/sum(frac)); - line=atsetfieldvalues(line,'KickAngle',{1,2},el.KickAngle(1,2)*frac(:)/sum(frac)); + line=atsetfieldvalues(line,'KickAngle',{1},el.KickAngle(1)*frac(:)/sum(frac)); + line=atsetfieldvalues(line,'KickAngle',{2},el.KickAngle(2)*frac(:)/sum(frac)); end line{1}=mvfield(line{1},entrancef); % Set back entrance fields diff --git a/atmat/lattice/atloadlattice.m b/atmat/lattice/atloadlattice.m index 04260fe06..56e7439a9 100644 --- a/atmat/lattice/atloadlattice.m +++ b/atmat/lattice/atloadlattice.m @@ -21,12 +21,16 @@ % the variable name must be specified using the 'matkey' keyword. % % .m Matlab function. The function must output a valid AT structure. +% .json JSON file +% +%see also atwritem, atwritejson persistent link_table if isempty(link_table) link_table.mat=@load_mat; link_table.m=@load_m; + link_table.json=@load_json; end [~,~,fext]=fileparts(fspec); @@ -57,7 +61,7 @@ dt=load(fpath); vnames=fieldnames(dt); key='RING'; - if length(vnames) == 1 + if isscalar(vnames) key=vnames{1}; else for v={'ring','lattice'} @@ -75,7 +79,31 @@ error('AT:load','Cannot find variable %s\nmatkey must be in: %s',... key, strjoin(vnames,', ')); end + end + function [lattice, opts]=load_json(fpath, opts) + data=jsondecode(fileread(fpath)); + % File signature for later use + try + atjson=data.atjson; + catch + atjson=1; + end + props=data.properties; + name=props.name; + energy=props.energy; + periodicity=props.periodicity; + particle=atparticle.loadobj(props.particle); + harmnumber=props.harmonic_number; + props=rmfield(props,{'name','energy','periodicity','particle','harmonic_number'}); + args=[fieldnames(props) struct2cell(props)]'; + lattice=atSetRingProperties(data.elements,... + 'FamName', name,... + 'Energy', energy,... + 'Periodicity', periodicity,... + 'Particle', particle,... + 'HarmNumber', harmnumber, ... + args{:}); end end \ No newline at end of file diff --git a/atmat/lattice/atwritejson.m b/atmat/lattice/atwritejson.m new file mode 100644 index 000000000..2ac8bfabe --- /dev/null +++ b/atmat/lattice/atwritejson.m @@ -0,0 +1,68 @@ +function varargout=atwritejson(ring, varargin) +%ATWRITEJSON Create a JSON file to store an AT lattice +% +%JS=ATWRITEJSON(RING) +% Return the JSON representation of RING as a character array +% +%ATWRITEJSON(RING, FILENAME) +% Write the JSON representation of RING to the file FILENAME +% +%ATWRITEJSON(RING, ..., 'compact', true) +% If compact is true, write a compact JSON file (no linefeeds) +% +%see also atloadlattice + +[compact, varargs]=getoption(varargin, 'compact', false); +[filename, ~]=getargs(varargs,[]); + +if ~isempty(filename) + %get filename information + [pname,fname,ext]=fileparts(filename); + + %Check file extension + if isempty(ext), ext='.json'; end + + % Open file to be written + [fid,mess]=fopen(fullfile(pname,[fname ext]),'wt'); + + if fid==-1 + error('AT:FileErr','Cannot Create file %s\n%s',fn,mess); + else + fprintf(fid, sjson(ring)); + fclose(fid); + end + varargout={}; +else + varargout={sjson(ring)}; +end + + function jsondata=sjson(ring) + ok=~atgetcells(ring, 'Class', 'RingParam'); + data.atjson= 1; + data.elements=ring(ok); + data.properties=get_params(ring); + jsondata=jsonencode(data, 'PrettyPrint', ~compact); + end + + function prms=get_params(ring) + % Get "standard" properties + [name, energy, part, periodicity, harmonic_number]=... + atGetRingProperties(ring,'FamName', 'Energy', 'Particle',... + 'Periodicity', 'HarmNumber'); + prms=struct('name', name, 'energy', energy, 'periodicity', periodicity,... + 'particle', saveobj(part), 'harmonic_number', harmonic_number); + % Add user-defined properties + idx=atlocateparam(ring); + if ~isempty(idx) + flist={'FamName','PassMethod','Length','Class',... + 'Energy', 'Particle','Periodicity','cell_harmnumber'}; + present=isfield(ring{idx}, flist); + p2=rmfield(ring{idx},flist(present)); + for nm=fieldnames(p2)' + na=nm{1}; + prms.(na)=p2.(na); + end + end + end + +end \ No newline at end of file diff --git a/pyat/at/lattice/elements.py b/pyat/at/lattice/elements.py index 4dbfff9bd..9a52c5e9a 100644 --- a/pyat/at/lattice/elements.py +++ b/pyat/at/lattice/elements.py @@ -302,26 +302,17 @@ def __setattr__(self, key, value): super(Element, self).__setattr__(key, value) def __str__(self): - first3 = ["FamName", "Length", "PassMethod"] - # Get values and parameter objects - attrs = dict(self.items()) - keywords = [f"\t{k} : {attrs.pop(k)!s}" for k in first3] - keywords += [f"\t{k} : {v!s}" for k, v in attrs.items()] - return "\n".join((type(self).__name__ + ":", "\n".join(keywords))) + return "\n".join( + [self.__class__.__name__ + ":"] + + [f"{k:>14}: {v!s}" for k, v in self.items()] + ) def __repr__(self): - # Get values only, even for parameters - attrs = dict((k, getattr(self, k)) for k, v in self.items()) - arguments = [attrs.pop(k) for k in self._BUILD_ATTRIBUTES] - defelem = self.__class__(*arguments) - keywords = [f"{v!r}" for v in arguments] - keywords += [ - f"{k}={v!r}" - for k, v in sorted(attrs.items()) - if not numpy.array_equal(v, getattr(defelem, k, None)) - ] + clsname, args, kwargs = self.definition + keywords = [f"{arg!r}" for arg in args] + keywords += [f"{k}={v!r}" for k, v in kwargs.items()] args = re.sub(r"\n\s*", " ", ", ".join(keywords)) - return "{0}({1})".format(self.__class__.__name__, args) + return f"{clsname}({args})" def equals(self, other) -> bool: """Whether an element is equivalent to another. @@ -399,10 +390,28 @@ def deepcopy(self) -> Element: """Return a deep copy of the element""" return deepcopy(self) + @property + def definition(self) -> tuple[str, tuple, dict]: + """tuple (class_name, args, kwargs) defining the element""" + attrs = dict(self.items()) + arguments = tuple(attrs.pop( + k, getattr(self, k)) for k in self._BUILD_ATTRIBUTES + ) + defelem = self.__class__(*arguments) + keywords = dict( + (k, v) + for k, v in attrs.items() + if not numpy.array_equal(v, getattr(defelem, k, None)) + ) + return self.__class__.__name__, arguments, keywords + def items(self) -> Generator[tuple[str, Any], None, None]: """Iterates through the data members""" - # Properties may be added by overloading this method - yield from vars(self).items() + v = vars(self).copy() + for k in ["FamName", "Length", "PassMethod"]: + yield k, v.pop(k) + for k, v in sorted(v.items()): + yield k, v def is_compatible(self, other: Element) -> bool: """Checks if another :py:class:`Element` can be merged""" diff --git a/pyat/at/load/__init__.py b/pyat/at/load/__init__.py index 28c187e3e..196af9f2a 100644 --- a/pyat/at/load/__init__.py +++ b/pyat/at/load/__init__.py @@ -8,3 +8,4 @@ from .reprfile import * from .tracy import * from .elegant import * +from .json import * diff --git a/pyat/at/load/allfiles.py b/pyat/at/load/allfiles.py index be051ef32..cb3b154cf 100644 --- a/pyat/at/load/allfiles.py +++ b/pyat/at/load/allfiles.py @@ -1,10 +1,16 @@ """Generic function to save and load python AT lattices. The format is determined by the file extension """ + +from __future__ import annotations + +__all__ = ["load_lattice", "save_lattice", "register_format"] + import os.path -from at.lattice import Lattice +from collections.abc import Callable +from typing import Optional -__all__ = ['load_lattice', 'save_lattice', 'register_format'] +from at.lattice import Lattice _load_extension = {} _save_extension = {} @@ -13,41 +19,31 @@ def load_lattice(filepath: str, **kwargs) -> Lattice: """Load a Lattice object from a file -The file format is indicated by the filepath extension. - -Parameters: - filepath: Name of the file - -Keyword Args: - name (str): Name of the lattice. - Default: taken from the file, or ``''`` - energy (float): Energy of the lattice - (default: taken from the file) - periodicity (int]): Number of periods - (default: taken from the file, or 1) - *: All other keywords will be set as :py:class:`.Lattice` - attributes - -Specific keywords for .mat files - -Keyword Args: - mat_key (str): Name of the Matlab variable containing - the lattice. Default: Matlab variable name if there is only one, - otherwise ``'RING'`` - check (bool): Run coherence tests. Default: :py:obj:`True` - quiet (bool): Suppress the warning for non-standard classes. - Default: :py:obj:`False` - keep_all (bool): Keep Matlab RingParam elements as Markers. - Default: :py:obj:`False` - -Returns: - lattice (Lattice): New :py:class:`.Lattice` object - -See Also: - :py:func:`.load_mat`, :py:func:`.load_m`, :py:func:`.load_repr`, - :py:func:`.load_elegant`, :py:func:`.load_tracy` - -.. Admonition:: Known extensions are: + The file format is indicated by the filepath extension. The file name is stored in + the *in_file* Lattice attribute. The selected variable, if relevant, is stored + in the *use* Lattice attribute. + + Parameters: + filepath: Name of the file + + Keyword Args: + use (str): Name of the variable containing the desired lattice. + Default: if there is a single variable, use it, otherwise select ``"RING"`` + name (str): Name of the lattice. + Default: taken from the file, or ``""`` + energy (float): Energy of the lattice + (default: taken from the file) + periodicity (int): Number of periods + (default: taken from the file, or 1) + *: All other keywords will be set as :py:class:`.Lattice` + attributes + + Returns: + lattice (Lattice): New :py:class:`.Lattice` object + + Check the format-specific function for specific keyword arguments: + + .. Admonition:: Known extensions are: """ _, ext = os.path.splitext(filepath) try: @@ -58,25 +54,18 @@ def load_lattice(filepath: str, **kwargs) -> Lattice: return load_func(filepath, **kwargs) -def save_lattice(ring: Lattice, filepath: str, **kwargs): +def save_lattice(ring: Lattice, filepath: str, **kwargs) -> None: """Save a Lattice object -The file format is indicated by the filepath extension. - -Parameters: - ring: Lattice description - filepath: Name of the file + The file format is indicated by the filepath extension. -Specific keywords for .mat files - -Keyword Args: - mat_key (str): Name of the Matlab variable containing the lattice. - Default: ``'RING'`` + Parameters: + ring: Lattice description + filepath: Name of the file -See Also: - :py:func:`.save_mat`, :py:func:`.save_m`, :py:func:`.save_repr` + Check the format-specific function for specific keyword arguments: -.. Admonition:: Known extensions are: + .. Admonition:: Known extensions are: """ _, ext = os.path.splitext(filepath) try: @@ -87,24 +76,26 @@ def save_lattice(ring: Lattice, filepath: str, **kwargs): return save_func(ring, filepath, **kwargs) -def register_format(extension: str, load_func=None, save_func=None, - descr: str = ''): +def register_format( + extension: str, + load_func: Optional[Callable[..., Lattice]] = None, + save_func: Optional[Callable[..., None]] = None, + descr: str = "", +): """Register format-specific processing functions Parameters: extension: File extension string. - load_func: load function. Default: :py:obj:`None` - save_func: save_lattice function Default: :py:obj:`None` - descr: File type description + load_func: load function. + save_func: save function. + descr: File type description. """ if load_func is not None: _load_extension[extension] = load_func - load_lattice.__doc__ += '\n {0:<10}'\ - '\n {1}\n'.format(extension, descr) + load_lattice.__doc__ += f"\n {extension:<10}\n {descr}\n" if save_func is not None: _save_extension[extension] = save_func - save_lattice.__doc__ += '\n {0:<10}'\ - '\n {1}\n'.format(extension, descr) + save_lattice.__doc__ += f"\n {extension:<10}\n {descr}\n" Lattice.load = staticmethod(load_lattice) diff --git a/pyat/at/load/elegant.py b/pyat/at/load/elegant.py index ebb99835a..237824ae7 100644 --- a/pyat/at/load/elegant.py +++ b/pyat/at/load/elegant.py @@ -337,10 +337,8 @@ def load_elegant(filename: str, **kwargs) -> Lattice: name (str): Name of the lattice. Default: taken from the file. energy (float): Energy of the lattice [eV] - periodicity(int): Number of periods. Default: taken from the - elements, or 1 - *: All other keywords will be set as Lattice - attributes + periodicity(int): Number of periods. Default: taken from the elements, or 1 + *: All other keywords will be set as Lattice attributes Returns: lattice (Lattice): New :py:class:`.Lattice` object @@ -354,7 +352,7 @@ def load_elegant(filename: str, **kwargs) -> Lattice: harmonic_number = kwargs.pop("harmonic_number") def elem_iterator(params, elegant_file): - with open(params.setdefault("elegant_file", elegant_file)) as f: + with open(params.setdefault("in_file", elegant_file)) as f: contents = f.read() element_lines = expand_elegant( contents, lattice_key, energy, harmonic_number @@ -370,4 +368,5 @@ def elem_iterator(params, elegant_file): 'lattice {}: {}'.format(filename, e)) -register_format(".lte", load_elegant, descr="Elegant format") +register_format( + ".lte", load_elegant, descr="Elegant format. See :py:func:`.load_elegant`.") diff --git a/pyat/at/load/json.py b/pyat/at/load/json.py new file mode 100644 index 000000000..12453e4ac --- /dev/null +++ b/pyat/at/load/json.py @@ -0,0 +1,109 @@ +""" +Handling of JSON files +""" + +from __future__ import annotations + +__all__ = ["save_json", "load_json"] + +from os.path import abspath +import json +from typing import Optional, Any + +import numpy as np + +from .allfiles import register_format +from .utils import keep_elements, keep_attributes +from ..lattice import Element, Lattice, Particle + + +class _AtEncoder(json.JSONEncoder): + """JSON encoder for specific AT types""" + + def default(self, obj): + if isinstance(obj, Element): + return obj.to_dict() + elif isinstance(obj, np.ndarray): + return obj.tolist() + elif isinstance(obj, Particle): + return obj.to_dict() + else: + return super().default(obj) + + +def save_json( + ring: Lattice, filename: Optional[str] = None, compact: bool = False +) -> None: + """Save a :py:class:`.Lattice` as a JSON file + + Parameters: + ring: Lattice description + filename: Name of the JSON file. Default: outputs on + :py:obj:`sys.stdout` + compact: If :py:obj:`False` (default), the JSON file is pretty-printed + with line feeds and indentation. Otherwise, the output is a single line. + + See Also: + :py:func:`.save_lattice` for a generic lattice-saving function. + :py:meth:`.Lattice.save` for a generic lattice-saving method. + """ + indent = None if compact else 2 + data = dict( + atjson=1, elements=list(keep_elements(ring)), properties=keep_attributes(ring) + ) + if filename is None: + print(json.dumps(data, cls=_AtEncoder, indent=indent)) + else: + with open(filename, "wt") as jsonfile: + json.dump(data, jsonfile, cls=_AtEncoder, indent=indent) + + +def load_json(filename: str, **kwargs) -> Lattice: + """Create a :py:class:`.Lattice` from a JSON file + + Parameters: + filename: Name of a JSON file + + Keyword Args: + *: All keywords update the lattice properties + + Returns: + lattice (Lattice): New :py:class:`.Lattice` object + + See Also: + :py:meth:`.Lattice.load` for a generic lattice-loading method. + """ + + def json_generator(params: dict[str, Any], fn): + + with open(params.setdefault("in_file", fn), "rt") as jsonfile: + data = json.load(jsonfile) + # Check the file signature - For later use + try: + atjson = data["atjson"] # noqa F841 + except KeyError: + atjson = 1 # noqa F841 + # Get elements + elements = data["elements"] + # Get lattice properties + try: + properties = data["properties"] + except KeyError: + properties = {} + particle_dict = properties.pop("particle", {}) + params.setdefault("particle", Particle(**particle_dict)) + + for k, v in properties.items(): + params.setdefault(k, v) + for idx, elem_dict in enumerate(elements): + yield Element.from_dict(elem_dict, index=idx, check=False) + + return Lattice(abspath(filename), iterator=json_generator, **kwargs) + + +register_format( + ".json", + load_json, + save_json, + descr="JSON representation of a python AT Lattice. See :py:func:`.load_json`.", +) diff --git a/pyat/at/load/matfile.py b/pyat/at/load/matfile.py index b58152d33..92d10666e 100644 --- a/pyat/at/load/matfile.py +++ b/pyat/at/load/matfile.py @@ -7,63 +7,71 @@ __all__ = ["load_mat", "save_mat", "load_m", "save_m", "load_var"] import sys +import os from os.path import abspath, basename, splitext from typing import Optional, Any from collections.abc import Sequence, Generator from warnings import warn -import numpy +import numpy as np import scipy.io +# imports necessary in 'globals()' for 'eval' +from numpy import array, uint8, NaN # noqa: F401 + from .allfiles import register_format -from .utils import element_from_dict, element_from_m, RingParam -from .utils import element_to_dict, element_to_m -from ..lattice import Element, Lattice, Filter +from .utils import split_ignoring_parentheses, RingParam, keep_elements +from .utils import _drop_attrs, _CLASS_MAP +from ..lattice import Element, Lattice, Particle, Filter from ..lattice import elements, AtWarning, params_filter, AtError +# Translation of RingParam attributes _m2p = { "FamName": "name", "Energy": "energy", "Periodicity": "periodicity", "Particle": "particle", - "cell_harmnumber": "cell_harmnumber", + "cell_harmnumber": "cell_harmnumber", # necessary: property + "beam_current": "beam_current", # necessary: property + "PassMethod": None, + "Length": None, + "cavpts": None, +} +_p2m = dict((v, k) for k, v in _m2p.items() if v is not None) +# Attribute to drop when writing a file +_p2m.update(_drop_attrs) + +# Python to Matlab type translation +_mattype_map = { + int: float, + np.ndarray: lambda attr: np.asanyarray(attr), + Particle: lambda attr: attr.to_dict(), } -_param_ignore = {"PassMethod", "Length", "cavpts"} +# Matlab constructor function +# Default: "".join(("at", element_class.__name__.lower())) +_mat_constructor = { + "Dipole": "atsbend", + "M66": "atM66", +} + -# Python to Matlab -_p2m = {"name", "energy", "periodicity", "particle", "cell_harmnumber", "beam_current"} +def _mat_encoder(v): + """type encoding for .mat files""" + return _mattype_map.get(type(v), lambda attr: attr)(v) -def matfile_generator( +def _matfile_generator( params: dict[str, Any], mat_file: str ) -> Generator[Element, None, None]: - """Run through Matlab cells and generate AT elements - - Parameters: - params: Lattice building parameters (see :py:class:`.Lattice`) - mat_file: File name - - The following keys in ``params`` are used: - - ============ =================== - **mat_key** name of the Matlab variable containing the lattice. - Default: Matlab variable name if there is only one, - otherwise 'RING' - **check** Skip the coherence tests - **quiet** Suppress the warning for non-standard classes - ============ =================== - - Yields: - elem (Element): new Elements - """ + """Run through Matlab cells and generate AT elements""" def mclean(data): - if data.dtype.type is numpy.str_: + if data.dtype.type is np.str_: # Convert strings in arrays back to strings. return str(data[0]) if data.size > 0 else "" elif data.size == 1: v = data[0, 0] - if issubclass(v.dtype.type, numpy.void): + if issubclass(v.dtype.type, np.void): # Object => Return a dict return {f: mclean(v[f]) for f in v.dtype.fields} else: @@ -71,17 +79,17 @@ def mclean(data): return v else: # Remove any surplus dimensions in arrays. - return numpy.squeeze(data) + return np.squeeze(data) # noinspection PyUnresolvedReferences - m = scipy.io.loadmat(params.setdefault("mat_file", mat_file)) + m = scipy.io.loadmat(params.setdefault("in_file", mat_file)) matvars = [varname for varname in m if not varname.startswith("__")] default_key = matvars[0] if (len(matvars) == 1) else "RING" - key = params.setdefault("mat_key", default_key) + key = params.setdefault("use", default_key) if key not in m.keys(): kok = [k for k in m.keys() if "__" not in k] raise AtError( - "Selected mat_key does not exist, please select in: {}".format(kok) + f"Selected '{key}' variable does not exist, please select in: {kok}" ) check = params.pop("check", True) quiet = params.pop("quiet", False) @@ -89,7 +97,7 @@ def mclean(data): for index, mat_elem in enumerate(cell_array): elem = mat_elem[0, 0] kwargs = {f: mclean(elem[f]) for f in elem.dtype.fields} - yield element_from_dict(kwargs, index=index, check=check, quiet=quiet) + yield Element.from_dict(kwargs, index=index, check=check, quiet=quiet) def ringparam_filter( @@ -132,8 +140,9 @@ def ringparam_filter( if isinstance(elem, RingParam): ringparams.append(elem) for k, v in elem.items(): - if k not in _param_ignore: - params.setdefault(_m2p.get(k, k), v) + k2 = _m2p.get(k, k) + if k2 is not None: + params.setdefault(k2, v) if keep_all: pars = vars(elem).copy() name = pars.pop("FamName") @@ -153,9 +162,10 @@ def load_mat(filename: str, **kwargs) -> Lattice: filename: Name of a '.mat' file Keyword Args: - mat_key (str): Name of the Matlab variable containing - the lattice. Default: Matlab variable name if there is only one, - otherwise 'RING' + use (str): Name of the Matlab variable containing + the lattice. Default: it there is a single variable, use it, otherwise + select 'RING' + mat_key (str): deprecated alias for *use* check (bool): Run the coherence tests. Default: :py:obj:`True` quiet (bool): Suppress the warning for non-standard @@ -178,42 +188,81 @@ def load_mat(filename: str, **kwargs) -> Lattice: :py:func:`.load_lattice` for a generic lattice-loading function. """ if "key" in kwargs: # process the deprecated 'key' keyword - kwargs.setdefault("mat_key", kwargs.pop("key")) + kwargs.setdefault("use", kwargs.pop("key")) + if "mat_key" in kwargs: # process the deprecated 'mat_key' keyword + kwargs.setdefault("use", kwargs.pop("mat_key")) return Lattice( ringparam_filter, - matfile_generator, + _matfile_generator, abspath(filename), iterator=params_filter, **kwargs, ) -def mfile_generator(params: dict, m_file: str) -> Generator[Element, None, None]: - """Run through the lines of a Matlab m-file and generate AT elements +def _element_from_m(line: str) -> Element: + """Builds an :py:class:`.Element` from a line in an m-file Parameters: - params: Lattice building parameters (see :py:class:`.Lattice`) - m_file: File name + line: Matlab string representation of an :py:class:`.Element` - Yields: - elem (Element): new Elements + Returns: + elem (Element): new :py:class:`.Element` """ - with open(params.setdefault("m_file", m_file), "rt") as file: - _ = next(file) # Matlab function definition - _ = next(file) # Cell array opening - for lineno, line in enumerate(file): - if line.startswith("};"): - break - try: - elem = element_from_m(line) - except ValueError: - warn(AtWarning("Invalid line {0} skipped.".format(lineno))) - continue - except KeyError: - warn(AtWarning("Line {0}: Unknown class.".format(lineno))) - continue - else: - yield elem + + def argsplit(value): + return [a.strip() for a in split_ignoring_parentheses(value)] + + def makedir(mat_struct): + """Build directory from Matlab struct arguments""" + + def pairs(it): + while True: + try: + a = next(it) + except StopIteration: + break + yield eval(a), convert(next(it)) + + return dict(pairs(iter(mat_struct))) + + def makearray(mat_arr): + """Build numpy array for Matlab array syntax""" + + def arraystr(arr): + lns = arr.split(";") + rr = [arraystr(v) for v in lns] if len(lns) > 1 else lns[0].split() + return f"[{', '.join(rr)}]" + + return eval(f"array({arraystr(mat_arr)})") + + def convert(value): + """convert Matlab syntax to numpy syntax""" + if value.startswith("["): + result = makearray(value[1:-1]) + elif value.startswith("struct"): + result = makedir(argsplit(value[7:-1])) + else: + result = eval(value) + return result + + left = line.index("(") + right = line.rindex(")") + matcls = line[:left].strip()[2:] + cls = _CLASS_MAP[matcls] + arguments = argsplit(line[left + 1 : right]) + ll = len(cls._BUILD_ATTRIBUTES) + if ll < len(arguments) and arguments[ll].endswith("Pass'"): + arguments.insert(ll, "'PassMethod'") + args = [convert(v) for v in arguments[:ll]] + kwargs = makedir(arguments[ll:]) + if matcls == "rbend": + # the Matlab 'rbend' has no equivalent in PyAT. This adds parameters + # necessary for using the python sector bend + halfangle = 0.5 * args[2] + kwargs.setdefault("EntranceAngle", halfangle) + kwargs.setdefault("ExitAngle", halfangle) + return cls(*args, **kwargs) def load_m(filename: str, **kwargs) -> Lattice: @@ -240,6 +289,26 @@ def load_m(filename: str, **kwargs) -> Lattice: See Also: :py:func:`.load_lattice` for a generic lattice-loading function. """ + + def mfile_generator(params: dict, m_file: str) -> Generator[Element, None, None]: + """Run through the lines of a Matlab m-file and generate AT elements""" + with open(params.setdefault("in_file", m_file), "rt") as file: + _ = next(file) # Matlab function definition + _ = next(file) # Cell array opening + for lineno, line in enumerate(file): + if line.startswith("};"): + break + try: + elem = _element_from_m(line) + except ValueError: + warn(AtWarning("Invalid line {0} skipped.".format(lineno))) + continue + except KeyError: + warn(AtWarning("Line {0}: Unknown class.".format(lineno))) + continue + else: + yield elem + return Lattice( ringparam_filter, mfile_generator, @@ -274,7 +343,7 @@ def load_var(matlat: Sequence[dict], **kwargs) -> Lattice: # noinspection PyUnusedLocal def var_generator(params, latt): for elem in latt: - yield element_from_dict(elem) + yield Element.from_dict(elem) return Lattice( ringparam_filter, var_generator, matlat, iterator=params_filter, **kwargs @@ -288,42 +357,100 @@ def required(rng): # Public lattice attributes params = dict((k, v) for k, v in vars(rng).items() if not k.startswith("_")) # Output the required attributes/properties - for k in _p2m: + for kp, km in _p2m.items(): try: - v = getattr(rng, k) + v = getattr(rng, kp) except AttributeError: pass else: - params.pop(k, None) - yield k, v + params.pop(kp, None) + if km is not None: + yield km, v # Output the remaining attributes yield from params.items() dct = dict(required(ring)) yield RingParam(**dct) - for elem in ring: - if not ( - isinstance(elem, elements.Marker) - and getattr(elem, "tag", None) == "RingParam" - ): - yield elem + yield from keep_elements(ring) -def save_mat(ring: Lattice, filename: str, mat_key: str = "RING") -> None: +def save_mat(ring: Lattice, filename: str, **kwargs) -> None: """Save a :py:class:`.Lattice` as a Matlab mat-file Parameters: - ring: Lattice description - filename: Name of the '.mat' file - mat_key (str): Name of the Matlab variable containing - the lattice. Default: ``'RING'`` + ring: Lattice description + filename: Name of the '.mat' file + + Keyword Args: + use (str): Name of the Matlab variable containing the lattice, Default: "RING" + mat_key (str): Deprecated, alias for *use* See Also: :py:func:`.save_lattice` for a generic lattice-saving function. """ - lring = tuple((element_to_dict(elem),) for elem in matlab_ring(ring)) - # noinspection PyUnresolvedReferences - scipy.io.savemat(filename, {mat_key: lring}, long_field_names=True) + # Ensure the lattice is a Matlab column vector: list(list) + use = kwargs.pop("mat_key", "RING") # For backward compatibility + use = kwargs.pop("use", use) + lring = [[el.to_dict(encoder=_mat_encoder)] for el in matlab_ring(ring)] + scipy.io.savemat(filename, {use: lring}, long_field_names=True) + + +def _element_to_m(elem: Element) -> str: + """Builds the Matlab-evaluable string for an :py:class:`.Element` + + Parameters: + elem: :py:class:`.Element` + + Returns: + mstr (str): Matlab string representation of the + :py:class:`.Element` attributes + """ + + def convert(arg): + def convert_dict(pdir): + def scan(d): + for k, v in d.items(): + yield convert(k) + yield convert(v) + + return "struct({0})".format(", ".join(scan(pdir))) + + def convert_array(arr): + if arr.ndim > 1: + lns = (str(list(ln)).replace(",", "")[1:-1] for ln in arr) + return "".join(("[", "; ".join(lns), "]")) + elif arr.ndim > 0: + return str(list(arr)).replace(",", "") + else: + return str(arr) + + if isinstance(arg, np.ndarray): + return convert_array(arg) + elif isinstance(arg, dict): + return convert_dict(arg) + elif isinstance(arg, Particle): + return convert_dict(arg.to_dict()) + else: + return repr(arg) + + def m_name(elclass): + classname = elclass.__name__ + return _mat_constructor.get(classname, "".join(("at", classname.lower()))) + + attrs = dict(elem.items()) + # noinspection PyProtectedMember + args = [attrs.pop(k, getattr(elem, k)) for k in elem._BUILD_ATTRIBUTES] + defelem = elem.__class__(*args) + kwds = dict( + (k, v) + for k, v in attrs.items() + if not np.array_equal(v, getattr(defelem, k, None)) + ) + argstrs = [convert(arg) for arg in args] + if "PassMethod" in kwds: + argstrs.append(convert(kwds.pop("PassMethod"))) + argstrs += [", ".join((repr(k), convert(v))) for k, v in kwds.items()] + return "{0:>15}({1});...".format(m_name(elem.__class__), ", ".join(argstrs)) def save_m(ring: Lattice, filename: Optional[str] = None) -> None: @@ -341,7 +468,7 @@ def save_m(ring: Lattice, filename: Optional[str] = None) -> None: def save(file): print("ring = {...", file=file) for elem in matlab_ring(ring): - print(element_to_m(elem), file=file) + print(_element_to_m(elem), file=file) print("};", file=file) if filename is None: @@ -354,5 +481,42 @@ def save(file): print("end", file=mfile) -register_format(".mat", load_mat, save_mat, descr="Matlab binary mat-file") -register_format(".m", load_m, save_m, descr="Matlab text m-file") +# Simulates the deprecated "mat_file" and "mat_key" attributes +def _mat_file(ring): + """.mat input file. Deprecated, use *in_file* instead.""" + try: + in_file = ring.in_file + except AttributeError: + raise AttributeError("'Lattice' object has no attribute 'mat_file'") + else: + _, ext = os.path.splitext(in_file) + if ext != ".mat": + raise AttributeError("'Lattice' object has no attribute 'mat_file'") + return in_file + + +def _mat_key(ring): + """selected Matlab variable. Deprecated, use *use* instead.""" + try: + mat_key = ring.use + except AttributeError: + raise AttributeError("'Lattice' object has no attribute 'mat_key'") + return mat_key + + +register_format( + ".mat", + load_mat, + save_mat, + descr="Matlab binary mat-file. See :py:func:`.load_mat`.", +) + +register_format( + ".m", + load_m, + save_m, + descr="Matlab text m-file. See :py:func:`.load_m`.", +) + +Lattice.mat_file = property(_mat_file, None, None) +Lattice.mat_key = property(_mat_key, None, None) diff --git a/pyat/at/load/reprfile.py b/pyat/at/load/reprfile.py index 045f47528..8a1fe7cc6 100644 --- a/pyat/at/load/reprfile.py +++ b/pyat/at/load/reprfile.py @@ -1,18 +1,41 @@ """Text representation of a python AT lattice with each element represented by its :py:func:`repr` string """ -from __future__ import print_function + +from __future__ import annotations + +__all__ = ["load_repr", "save_repr"] + import sys from os.path import abspath from typing import Optional -import numpy -from at.lattice import Lattice -from at.load import register_format -from at.load.utils import element_from_string + +import numpy as np + +# imports necessary in 'globals()' for 'eval' +from numpy import array, uint8, NaN # noqa: F401 + +from at.lattice import Lattice, Element + # imports necessary in' globals()' for 'eval' from at.lattice import Particle # noqa: F401 +from at.load import register_format +from at.load.utils import element_classes, keep_attributes, keep_elements + +# Map class names to Element classes +_CLASS_MAP = dict((cls.__name__, cls) for cls in element_classes()) -__all__ = ['load_repr', 'save_repr'] + +def _element_from_string(elem_string: str) -> Element: + """Builds an :py:class:`.Element` from its python :py:func:`repr` string + + Parameters: + elem_string: String representation of an :py:class:`.Element` + + Returns: + elem (Element): new :py:class:`.Element` + """ + return eval(elem_string, globals(), _CLASS_MAP) def load_repr(filename: str, **kwargs) -> Lattice: @@ -37,13 +60,14 @@ def load_repr(filename: str, **kwargs) -> Lattice: See Also: :py:func:`.load_lattice` for a generic lattice-loading function. """ + def elem_iterator(params, repr_file): - with open(params.setdefault('repr_file', repr_file), 'rt') as file: + with open(params.setdefault("in_file", repr_file), "rt") as file: # the 1st line is the dictionary of saved lattice parameters for k, v in eval(next(file)).items(): params.setdefault(k, v) for line in file: - yield element_from_string(line.strip()) + yield _element_from_string(line.strip()) return Lattice(abspath(filename), iterator=elem_iterator, **kwargs) @@ -59,25 +83,25 @@ def save_repr(ring: Lattice, filename: Optional[str] = None) -> None: See Also: :py:func:`.save_lattice` for a generic lattice-saving function. """ + def save(file): - # print(repr(dict((k, v) for k, v in vars(ring).items() - # if not k.startswith('_'))), file=file) - print(repr(ring.attrs), file=file) - for elem in ring: + print(repr(keep_attributes(ring)), file=file) + for elem in keep_elements(ring): print(repr(elem), file=file) - # Save the current options - opts = numpy.get_printoptions() # Set options to print the full representation of float variables - numpy.set_printoptions(formatter={'float_kind': repr}) - if filename is None: - save(sys.stdout) - else: - with open(filename, 'wt') as reprfile: - save(reprfile) - # Restore the current options - numpy.set_printoptions(**opts) - - -register_format('.repr', load_repr, save_repr, - descr='Text representation of a python AT Lattice') + with np.printoptions(formatter={"float_kind": repr}): + if filename is None: + save(sys.stdout) + else: + with open(filename, "wt") as reprfile: + save(reprfile) + + +register_format( + ".repr", + load_repr, + save_repr, + descr=("Text representation of a python AT Lattice. " + "See :py:func:`.load_repr`."), +) diff --git a/pyat/at/load/tracy.py b/pyat/at/load/tracy.py index 3630ed497..b82df1ceb 100644 --- a/pyat/at/load/tracy.py +++ b/pyat/at/load/tracy.py @@ -4,6 +4,7 @@ This parser is quite similar to the Elegant parser in elegant.py. """ + import logging as log from os.path import abspath import re @@ -21,7 +22,7 @@ from at.lattice import Lattice from at.load import register_format, utils -__all__ = ['load_tracy'] +__all__ = ["load_tracy"] def create_drift(name, params, variables): @@ -51,15 +52,13 @@ def create_sext(name, params, variables): def create_dipole(name, params, variables): length = parse_float(params.pop("l", 0), variables) - params["NumIntSteps"] = parse_float(params.pop("n", 10), - variables) + params["NumIntSteps"] = parse_float(params.pop("n", 10), variables) params["PassMethod"] = "BndMPoleSymplectic4Pass" - params["BendingAngle"] = (parse_float(params.pop("t"), - variables) / 180) * numpy.pi - params["EntranceAngle"] = (parse_float(params.pop("t1"), - variables) / 180) * numpy.pi - params["ExitAngle"] = (parse_float(params.pop("t2"), - variables) / 180) * numpy.pi + params["BendingAngle"] = (parse_float(params.pop("t"), variables) / 180) * numpy.pi + params["EntranceAngle"] = ( + parse_float(params.pop("t1"), variables) / 180 + ) * numpy.pi + params["ExitAngle"] = (parse_float(params.pop("t2"), variables) / 180) * numpy.pi # Tracy is encoding gap plus fringe int in the 'gap' field. # Since BndMPoleSymplectic4Pass only uses the product of FringeInt # and gap we can substitute the following. @@ -97,8 +96,7 @@ def create_cavity(name, params, variables): params["Phi"] = parse_float(params.pop("phi", 0), variables) harmonic_number = variables["harmonic_number"] energy = variables["energy"] - return RFCavity(name, length, voltage, frequency, - harmonic_number, energy, **params) + return RFCavity(name, length, voltage, frequency, harmonic_number, energy, **params) ELEMENT_MAP = { @@ -187,27 +185,40 @@ def evaluate(tokens): try: b1 = tokens.index("(") b2 = len(tokens) - 1 - tokens[::-1].index(")") - return evaluate(tokens[:b1] + [evaluate(tokens[b1 + 1:b2])] - + tokens[b2 + 1:]) + return evaluate( + tokens[:b1] + [evaluate(tokens[b1 + 1 : b2])] + tokens[b2 + 1 :] + ) except ValueError: # No open parentheses found. pass # Evaluate / and * from left to right. for i, token in enumerate(tokens[:-1]): if token == "/": - return evaluate(tokens[:i-1] + [float(tokens[i-1]) - / float(tokens[i+1])] + tokens[i+2:]) + return evaluate( + tokens[: i - 1] + + [float(tokens[i - 1]) / float(tokens[i + 1])] + + tokens[i + 2 :] + ) if token == "*": - return evaluate(tokens[:i-1] + [float(tokens[i-1]) - * float(tokens[i+1])] + tokens[i+2:]) + return evaluate( + tokens[: i - 1] + + [float(tokens[i - 1]) * float(tokens[i + 1])] + + tokens[i + 2 :] + ) # Evaluate + and - from left to right. for i, token in enumerate(tokens[:-1]): if token == "+": - return evaluate(tokens[:i-1] + [float(tokens[i-1]) - + float(tokens[i+1])] + tokens[i+2:]) + return evaluate( + tokens[: i - 1] + + [float(tokens[i - 1]) + float(tokens[i + 1])] + + tokens[i + 2 :] + ) if token == "-": - return evaluate(tokens[:i-1] + [float(tokens[i-1]) - - float(tokens[i+1])] + tokens[i+2:]) + return evaluate( + tokens[: i - 1] + + [float(tokens[i - 1]) - float(tokens[i + 1])] + + tokens[i + 2 :] + ) return evaluate(tokens) @@ -288,9 +299,7 @@ def expand_tracy(contents, lattice_key, harmonic_number): else: key, value = line.split(":") if value.split(",")[0].strip() in ELEMENT_MAP: - elements[key] = tracy_element_from_string(key, - value, - variables) + elements[key] = tracy_element_from_string(key, value, variables) else: chunk = parse_chunk(value, elements, chunks) chunks[key] = chunk @@ -351,13 +360,12 @@ def load_tracy(filename: str, **kwargs) -> Lattice: filename: Name of a Tracy file Keyword Args: - name (str): Name of the lattice. Default: taken from - the file. + use (str): Name of the variable containing the desired lattice. + Default: ``cell`` + name (str): Name of the lattice. Default: taken from the file. energy (float): Energy of the lattice [eV] - periodicity(int): Number of periods. Default: taken from the - elements, or 1 - *: All other keywords will be set as Lattice - attributes + periodicity(int): Number of periods. Default: taken from the elements, or 1 + *: All other keywords will be set as Lattice attributes Returns: lattice (Lattice): New :py:class:`.Lattice` object @@ -366,11 +374,8 @@ def load_tracy(filename: str, **kwargs) -> Lattice: :py:func:`.load_lattice` for a generic lattice-loading function. """ try: - harmonic_number = kwargs.pop("harmonic_number") - lattice_key = kwargs.pop("lattice_key", "cell") - def elem_iterator(params, tracy_file): - with open(params.setdefault("tracy_file", tracy_file)) as f: + with open(params.setdefault("in_file", tracy_file)) as f: contents = f.read() element_lines, energy = expand_tracy( contents, lattice_key, harmonic_number @@ -379,10 +384,13 @@ def elem_iterator(params, tracy_file): for line in element_lines: yield line + if "lattice_key" in kwargs: # process the deprecated 'lattice_key' keyword + kwargs.setdefault("use", kwargs.pop("lattice_key")) + harmonic_number = kwargs.pop("harmonic_number") + lattice_key = kwargs.pop("use", "cell") return Lattice(abspath(filename), iterator=elem_iterator, **kwargs) except Exception as e: - raise ValueError('Failed to load tracy ' - 'lattice {}: {}'.format(filename, e)) + raise ValueError("Failed to load tracy " "lattice {}: {}".format(filename, e)) -register_format(".lat", load_tracy, descr="Tracy format") +register_format(".lat", load_tracy, descr="Tracy format. See :py:func:`.load_tracy`.") diff --git a/pyat/at/load/utils.py b/pyat/at/load/utils.py index 446f618f6..3d0eccddf 100644 --- a/pyat/at/load/utils.py +++ b/pyat/at/load/utils.py @@ -4,36 +4,50 @@ from __future__ import annotations +__all__ = [ + "element_classes", + "element_from_dict", + "element_to_dict", + "find_class", + "keep_elements", + "keep_attributes", + "split_ignoring_parentheses", + "RingParam", +] + import collections import os import re import sysconfig -from typing import Optional +from typing import Any, Optional from warnings import warn +from collections.abc import Callable, Generator import numpy as np -# imports necessary in' globals()' for 'eval' -from numpy import array, uint8, NaN # noqa: F401 - from at import integrators from at.lattice import AtWarning -from at.lattice import CLASS_MAP, elements as elt -from at.lattice import Particle, Element +from at.lattice import elements as elt +from at.lattice import Lattice, Particle, Element, Marker from at.lattice import idtable_element _ext_suffix = sysconfig.get_config_var("EXT_SUFFIX") -_relativistic_particle = Particle() + + +def _no_encoder(v): + """type encoding for .mat files""" + return v def _particle(value) -> Particle: if isinstance(value, Particle): # Create from python: save_mat return value - else: + elif isinstance(value, dict): # Create from Matlab: load_mat - name = value.pop("name") - return Particle(name, **value) + return Particle(**value) + else: + return Particle(value) def _warn(index: int, message: str, elem_dict: dict) -> None: @@ -43,6 +57,20 @@ def _warn(index: int, message: str, elem_dict: dict) -> None: warn(AtWarning(warning), stacklevel=2) +def element_classes() -> frozenset[type[Element]]: + """Build a set of all Element subclasses""" + + # Misses class aliases (Bend, Matrix66) + def subclasses_recursive(cl): + direct = cl.__subclasses__() + indirect = [] + for subclass in direct: + indirect.extend(subclasses_recursive(subclass)) + return frozenset([cl] + direct + indirect) + + return subclasses_recursive(Element) + + class RingParam(elt.Element): """Private class for Matlab RingParam element @@ -58,23 +86,22 @@ class RingParam(elt.Element): elt.Element._conversions, Energy=float, Periodicity=int, Particle=_particle ) + # noinspection PyPep8Naming def __init__( self, - name: str, - energy: float, - periodicity: int = 1, - particle: Particle = _relativistic_particle, + FamName: str, + Energy: float, + Periodicity: int, **kwargs, ): - if not np.isnan(float(energy)): - kwargs.setdefault("Energy", energy) - kwargs.setdefault("Periodicity", periodicity) - kwargs.setdefault("Particle", particle) + if not np.isnan(float(Energy)): + kwargs["Energy"] = Energy kwargs.setdefault("PassMethod", "IdentityPass") - super(RingParam, self).__init__(name, **kwargs) + super(RingParam, self).__init__(FamName, Periodicity=Periodicity, **kwargs) _alias_map = { + "bend": elt.Dipole, "rbend": elt.Dipole, "sbend": elt.Dipole, "quad": elt.Quadrupole, @@ -84,15 +111,14 @@ def __init__( "ap": elt.Aperture, "ringparam": RingParam, "wig": elt.Wiggler, - "insertiondevicekickmap": idtable_element.InsertionDeviceKickMap, "matrix66": elt.M66, } - -# Matlab to Python class translation -_CLASS_MAP = dict((k.lower(), v) for k, v in CLASS_MAP.items()) +# Map class names to Element classes +_CLASS_MAP = dict((cls.__name__.lower(), cls) for cls in element_classes()) _CLASS_MAP.update(_alias_map) +# Maps passmethods to Element classes _PASS_MAP = { "BendLinearPass": elt.Dipole, "BndMPoleSymplectic4RadPass": elt.Dipole, @@ -110,36 +136,25 @@ def __init__( "GWigSymplecticPass": elt.Wiggler, } -# Matlab to Python attribute translation -_param_to_lattice = { - "Energy": "energy", - "Periodicity": "periodicity", - "FamName": "name", -} - -# Python to Matlab class translation -_matclass_map = { +# Maps python class name to Matlab class +# Default: element_class.__name__ +_mat_class = { "Dipole": "Bend", - "InsertionDeviceKickMap": "InsertionDeviceKickMap", "M66": "Matrix66", } -# Python to Matlab type translation -_mattype_map = { - int: float, - np.ndarray: lambda attr: np.asanyarray(attr), - Particle: lambda attr: attr.to_dict(), -} - -_class_to_matfunc = { - elt.Dipole: "atsbend", - elt.Bend: "atsbend", - elt.M66: "atM66", - idtable_element.InsertionDeviceKickMap: "atinsertiondevicekickmap", +# Lattice attributes which must be dropped when writing a file +_drop_attrs = { + "in_file": None, + "use": None, + "mat_key": None, + "mat_file": None, # Not used anymore... + "m_file": None, + "repr_file": None, } -def hasattrs(kwargs: dict, *attributes) -> bool: +def _hasattrs(kwargs: dict, *attributes) -> bool: """Checks the presence of keys in a :py:class:`dict` Returns :py:obj:`True` if any of the ``attributes`` is in ``kwargs`` @@ -159,10 +174,79 @@ def hasattrs(kwargs: dict, *attributes) -> bool: return False +def keep_attributes(ring: Lattice): + """Remove Lattice attributes which must not be saved on file""" + return dict( + (k, v) for k, v in ring.attrs.items() if _drop_attrs.get(k, k) is not None + ) + + +def keep_elements(ring: Lattice) -> Generator[Element, None, None]: + """Remove the 'RingParam' Marker""" + for elem in ring: + if not (isinstance(elem, Marker) and getattr(elem, "tag", None) == "RingParam"): + yield elem + + +def _from_contents(elem: dict) -> type[Element]: + """Deduce the element class from its contents""" + + def low_order(key): + polynom = np.array(elem[key], dtype=np.float64).reshape(-1) + try: + low = np.where(polynom != 0.0)[0][0] + except IndexError: + low = -1 + return low + + length = float(elem.get("Length", 0.0)) + pass_method = elem.get("PassMethod", "") + if _hasattrs( + elem, "FullGap", "FringeInt1", "FringeInt2", "gK", "EntranceAngle", "ExitAngle" + ): + return elt.Dipole + elif _hasattrs(elem, "Voltage", "Frequency", "HarmNumber", "PhaseLag", "TimeLag"): + return elt.RFCavity + elif _hasattrs(elem, "Periodicity"): + # noinspection PyProtectedMember + return RingParam + elif _hasattrs(elem, "Limits"): + return elt.Aperture + elif _hasattrs(elem, "M66"): + return elt.M66 + elif _hasattrs(elem, "K"): + return elt.Quadrupole + elif _hasattrs(elem, "PolynomB", "PolynomA"): + loworder = low_order("PolynomB") + if loworder == 1: + return elt.Quadrupole + elif loworder == 2: + return elt.Sextupole + elif loworder == 3: + return elt.Octupole + elif pass_method.startswith("StrMPoleSymplectic4") or (length > 0): + return elt.Multipole + else: + return elt.ThinMultipole + elif _hasattrs(elem, "KickAngle"): + return elt.Corrector + elif length > 0.0: + return elt.Drift + elif _hasattrs(elem, "GCR"): + return elt.Monitor + elif pass_method == "IdentityPass": + return elt.Marker + else: + return elt.Element + + def find_class( elem_dict: dict, quiet: bool = False, index: Optional[int] = None ) -> type(Element): - """Identify the class of an element from its attributes + """Deduce the class of an element from its attributes + + `find_class` looks first at the "Class" field, if existing. It then tries to deduce + the class from "FamName", from "PassMethod", and finally form the element contents. Args: elem_dict: The dictionary of keyword arguments passed to the @@ -174,88 +258,40 @@ def find_class( element_class: The guessed Class name """ - def low_order(key): - polynom = np.array(elem_dict[key], dtype=np.float64).reshape(-1) - try: - low = np.where(polynom != 0.0)[0][0] - except IndexError: - low = -1 - return low - - class_name = elem_dict.pop("Class", "") - try: - return _CLASS_MAP[class_name.lower()] - except KeyError: - if not quiet and class_name: - _warn(index, f"Class '{class_name}' does not exist.", elem_dict) - fam_name = elem_dict.get("FamName", "") - try: - return _CLASS_MAP[fam_name.lower()] - except KeyError: - pass_method = elem_dict.get("PassMethod", "") - if not quiet and not pass_method: - _warn(index, "No PassMethod provided.", elem_dict) - elif not quiet and not pass_method.endswith("Pass"): - message = ( - f"Invalid PassMethod '{pass_method}', " - "provided pass methods should end in 'Pass'." - ) - _warn(index, message, elem_dict) - class_from_pass = _PASS_MAP.get(pass_method) - if class_from_pass is not None: - return class_from_pass - else: - length = float(elem_dict.get("Length", 0.0)) - if hasattrs( - elem_dict, - "FullGap", - "FringeInt1", - "FringeInt2", - "gK", - "EntranceAngle", - "ExitAngle", - ): - return elt.Dipole - elif hasattrs( - elem_dict, - "Voltage", - "Frequency", - "HarmNumber", - "PhaseLag", - "TimeLag", - ): - return elt.RFCavity - elif hasattrs(elem_dict, "Periodicity"): - # noinspection PyProtectedMember - return RingParam - elif hasattrs(elem_dict, "Limits"): - return elt.Aperture - elif hasattrs(elem_dict, "M66"): - return elt.M66 - elif hasattrs(elem_dict, "K"): - return elt.Quadrupole - elif hasattrs(elem_dict, "PolynomB", "PolynomA"): - loworder = low_order("PolynomB") - if loworder == 1: - return elt.Quadrupole - elif loworder == 2: - return elt.Sextupole - elif loworder == 3: - return elt.Octupole - elif pass_method.startswith("StrMPoleSymplectic4") or (length > 0): - return elt.Multipole - else: - return elt.ThinMultipole - elif hasattrs(elem_dict, "KickAngle"): - return elt.Corrector - elif length > 0.0: - return elt.Drift - elif hasattrs(elem_dict, "GCR"): - return elt.Monitor - elif pass_method == "IdentityPass": - return elt.Marker - else: - return elt.Element + def check_class(clname): + if clname: + _warn(index, f"Class '{clname}' does not exist.", elem_dict) + + def check_pass(passm): + if not passm: + _warn(index, "No PassMethod provided.", elem_dict) + elif not passm.endswith("Pass"): + message = ( + f"Invalid PassMethod '{passm}': " + "provided pass methods should end in 'Pass'." + ) + _warn(index, message, elem_dict) + + class_name = elem_dict.pop("Class", "") # try from class name + cls = _CLASS_MAP.get(class_name.lower(), None) + if cls is not None: + return cls + elif not quiet: + check_class(class_name) + + elname = elem_dict.get("FamName", "") # try from element name + cls = _CLASS_MAP.get(elname.lower(), None) + if cls is not None: + return cls + + pass_method = elem_dict.get("PassMethod", "") # try from passmethod + cls = _PASS_MAP.get(pass_method, None) + if cls is not None: + return cls + elif not quiet: + check_pass(pass_method) + + return _from_contents(elem_dict) # look for contents def element_from_dict( @@ -326,172 +362,27 @@ def sanitise_class(index, cls, elem_dict): return element -def element_from_string(elem_string: str) -> Element: - """Builds an :py:class:`.Element` from its python :py:func:`repr` string - - Parameters: - elem_string: String representation of an :py:class:`.Element` - - Returns: - elem (Element): new :py:class:`.Element` - """ - return eval(elem_string, globals(), CLASS_MAP) - - -def element_from_m(line: str) -> Element: - """Builds an :py:class:`.Element` from a line in an m-file - - Parameters: - line: Matlab string representation of an :py:class:`.Element` - - Returns: - elem (Element): new :py:class:`.Element` - """ - - def argsplit(value): - return [a.strip() for a in split_ignoring_parentheses(value, ",")] - - def makedir(mat_struct): - """Build directory from Matlab struct arguments""" - - def pairs(it): - while True: - try: - a = next(it) - except StopIteration: - break - yield eval(a), convert(next(it)) - - return dict(pairs(iter(mat_struct))) - - def makearray(mat_arr): - """Build numpy array for Matlab array syntax""" - - def arraystr(arr): - lns = arr.split(";") - rr = [arraystr(v) for v in lns] if len(lns) > 1 else lns[0].split() - return "[{0}]".format(", ".join(rr)) - - return eval("array({0})".format(arraystr(mat_arr))) - - def convert(value): - """convert Matlab syntax to numpy syntax""" - if value.startswith("["): - result = makearray(value[1:-1]) - elif value.startswith("struct"): - result = makedir(argsplit(value[7:-1])) - else: - result = eval(value) - return result - - left = line.index("(") - right = line.rindex(")") - matcls = line[:left].strip()[2:] - cls = _CLASS_MAP[matcls] - arguments = argsplit(line[left + 1 : right]) - ll = len(cls._BUILD_ATTRIBUTES) - if ll < len(arguments) and arguments[ll].endswith("Pass'"): - arguments.insert(ll, "'PassMethod'") - args = [convert(v) for v in arguments[:ll]] - kwargs = makedir(arguments[ll:]) - if matcls == "rbend": - # the Matlab 'rbend' has no equivalent in PyAT. This adds parameters - # necessary for using the python sector bend - halfangle = 0.5 * args[2] - kwargs.setdefault("EntranceAngle", halfangle) - kwargs.setdefault("ExitAngle", halfangle) - return cls(*args, **kwargs) - - -def element_to_dict(elem: Element) -> dict: - """Builds the Matlab structure of an :py:class:`.Element` +def element_to_dict(elem: Element, encoder: Callable[[Any], Any] = _no_encoder) -> dict: + """Convert a :py:class:`.Element` to a :py:class:`dict` Parameters: elem: :py:class:`.Element` + encoder: data converter Returns: dct (dict): Dictionary of :py:class:`.Element` attributes """ - dct = dict( - (k, _mattype_map.get(type(v), lambda attr: attr)(v)) for k, v in elem.items() - ) + dct = dict((k, encoder(v)) for k, v in elem.items()) class_name = elem.__class__.__name__ - dct["Class"] = _matclass_map.get(class_name, class_name) + dct["Class"] = _mat_class.get(class_name, class_name) return dct -def element_to_m(elem: Element) -> str: - """Builds the Matlab-evaluable string for an :py:class:`.Element` - - Parameters: - elem: :py:class:`.Element` +def split_ignoring_parentheses(string: str, delimiter: str = ",") -> list[str]: + """Split a string while keeping parenthesized expressions intact - Returns: - mstr (str): Matlab string representation of the - :py:class:`.Element` attributes + Example: "l=0,hom(4,0.0,0)" -> ["l=0", "hom(4,0.0,0)"] """ - - def convert(arg): - def convert_dict(pdir): - def scan(d): - for k, v in d.items(): - yield convert(k) - yield convert(v) - - return "struct({0})".format(", ".join(scan(pdir))) - - def convert_array(arr): - if arr.ndim > 1: - lns = (str(list(ln)).replace(",", "")[1:-1] for ln in arr) - return "".join(("[", "; ".join(lns), "]")) - elif arr.ndim > 0: - return str(list(arr)).replace(",", "") - else: - return str(arr) - - if isinstance(arg, np.ndarray): - return convert_array(arg) - elif isinstance(arg, dict): - return convert_dict(arg) - elif isinstance(arg, Particle): - return convert_dict(arg.to_dict()) - else: - return repr(arg) - - def m_name(elclass): - stdname = "".join(("at", elclass.__name__.lower())) - return _class_to_matfunc.get(elclass, stdname) - - attrs = dict(elem.items()) - # noinspection PyProtectedMember - args = [attrs.pop(k, getattr(elem, k)) for k in elem._BUILD_ATTRIBUTES] - defelem = elem.__class__(*args) - kwds = dict( - (k, v) - for k, v in attrs.items() - if not np.array_equal(v, getattr(defelem, k, None)) - ) - argstrs = [convert(arg) for arg in args] - if "PassMethod" in kwds: - argstrs.append(convert(kwds.pop("PassMethod"))) - argstrs += [", ".join((repr(k), convert(v))) for k, v in kwds.items()] - return "{0:>15}({1});...".format(m_name(elem.__class__), ", ".join(argstrs)) - - -# Kept for compatibility but should be deprecated: - - -CLASS_MAPPING = dict((key, cls.__name__) for (key, cls) in _CLASS_MAP.items()) - -PASS_MAPPING = dict((key, cls.__name__) for (key, cls) in _PASS_MAP.items()) - - -def find_class_name(elem_dict, quiet=False): - """Derive the class name of an Element from its attributes""" - return find_class(elem_dict, quiet=quiet).__name__ - - -def split_ignoring_parentheses(string, delimiter): placeholder = "placeholder" substituted = string[:] matches = collections.deque(re.finditer("\\(.*?\\)", string)) @@ -507,3 +398,7 @@ def split_ignoring_parentheses(string, delimiter): assert not matches return replaced_parts + + +Element.from_dict = staticmethod(element_from_dict) +Element.to_dict = element_to_dict diff --git a/pyat/test/test_basic_elements.py b/pyat/test/test_basic_elements.py index 94ec5c349..380f4b424 100644 --- a/pyat/test/test_basic_elements.py +++ b/pyat/test/test_basic_elements.py @@ -15,8 +15,8 @@ def test_data_checks(): def test_element_string_ordering(): d = elements.Drift('D0', 1, attr=numpy.array(0)) - assert d.__str__() == ("Drift:\n\tFamName : D0\n\tLength : 1.0\n" - "\tPassMethod : DriftPass\n\tattr : 0") + assert d.__str__() == ("Drift:\n FamName: D0\n Length: 1.0\n" + " PassMethod: DriftPass\n attr: 0") assert d.__repr__() == "Drift('D0', 1.0, attr=array(0))" diff --git a/pyat/test/test_load_save.py b/pyat/test/test_load_save.py new file mode 100644 index 000000000..b45b1553a --- /dev/null +++ b/pyat/test/test_load_save.py @@ -0,0 +1,36 @@ +import os +from tempfile import mktemp + +import pytest +from numpy.testing import assert_equal + +from at.lattice import Lattice + + +@pytest.mark.parametrize("lattice", ["dba_lattice", "hmba_lattice"]) +@pytest.mark.parametrize( + "suffix, options", + [ + (".m", {}), + (".repr", {}), + (".mat", {"use": "abcd"}), + (".json", {}), + ], +) +def test_m(request, lattice, suffix, options): + ring0 = request.getfixturevalue(lattice) + fname = mktemp(suffix=suffix) + + # Create a new .m or .repr file + ring0.save(fname, **options) + + # load the new file + ring1 = Lattice.load(fname, **options) + + # Check that we get back the original lattice + el1, rg1, _ = ring0.linopt6() + el2, rg2, _ = ring1.linopt6() + assert_equal(rg1.tune, rg2.tune) + assert_equal(rg1.chromaticity, rg2.chromaticity) + + os.unlink(fname) diff --git a/pyat/test/test_load_utils.py b/pyat/test/test_load_utils.py index 470328756..78c5de551 100644 --- a/pyat/test/test_load_utils.py +++ b/pyat/test/test_load_utils.py @@ -5,8 +5,10 @@ from at import AtWarning from at.lattice import Lattice, elements, params_filter, no_filter from at.load.utils import find_class, element_from_dict +# noinspection PyProtectedMember from at.load.utils import _CLASS_MAP, _PASS_MAP from at.load.utils import RingParam, split_ignoring_parentheses +# noinspection PyProtectedMember from at.load.matfile import ringparam_filter @@ -66,8 +68,8 @@ def test_inconsistent_energy_values_warns_correctly(): def test_more_than_one_RingParam_in_ring_raises_warning(): - p1 = RingParam('rp1', 5.e+6) - p2 = RingParam('rp2', 3.e+6) + p1 = RingParam('rp1', 5.e+6, 1) + p2 = RingParam('rp2', 3.e+6, 1) with pytest.warns(AtWarning): params = _matlab_scanner([p1, p2]) assert params['_energy'] == 5.e+6