Skip to content

Commit

Permalink
Added test_error.py
Browse files Browse the repository at this point in the history
  • Loading branch information
lfarv committed Dec 20, 2022
1 parent b88232f commit 629c0ff
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 22 deletions.
20 changes: 13 additions & 7 deletions pyat/at/errors/error_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,13 @@
__all__ = ['find_orbit_err', 'get_optics_err', 'lattice_pass_err',
'assign_errors', 'enable_errors']

_BPM_ATTRS = ('BPMGain', 'BPMOffset', 'BPMTilt')
_ERR_ATTRS = ('PolynomBErr', 'PolynomAErr', 'ShiftErr', 'RotationErr')
_SEL_ARGS = ('all', 'PolynomAIndex', 'PolynomBIndex')
_BPM_ATTRS = {'BPMGain': (2,), 'BPMOffset': (2,), 'BPMTilt': (1,)}
_ERR_ATTRS = {'PolynomBErr': None, 'PolynomAErr': None, 'ShiftErr': (2,),
'RotationErr': None}
_ALL_ATTRS = dict(**_BPM_ATTRS, **_ERR_ATTRS)

_SEL_ARGS = tuple(_ERR_ATTRS.keys()) + \
('all', 'PolynomAIndex', 'PolynomBIndex')


def _truncated_randn(truncation=None, **kwargs):
Expand Down Expand Up @@ -125,7 +129,7 @@ def assign_errors(ring: Lattice, refpts: Refpts,
:py:func:`enable_errors`, :py:func:`get_optics_err`
"""
elements = ring[refpts]
for attr in _BPM_ATTRS + _ERR_ATTRS:
for attr, sz in _ALL_ATTRS.items():
val = kwargs.pop(attr, None)
if val is not None:
if isinstance(val, tuple):
Expand All @@ -135,7 +139,9 @@ def assign_errors(ring: Lattice, refpts: Refpts,
else:
rand = np.atleast_2d(val)
syst = np.zeros(rand.shape)
rv = _truncated_randn(size=(len(elements), rand.shape[-1]),
if sz is None:
sz = (rand.shape[-1],)
rv = _truncated_randn(size=((len(elements),) + sz),
truncation=truncation, random_state=seed)
try:
vals = syst + rv*rand
Expand All @@ -159,7 +165,7 @@ def _apply_bpm_orbit_error(ring, refpts, orbit):
if hasattr(e, 'BPMOffset'):
o6[:, [0, 2]] += e.BPMOffset
if hasattr(e, 'BPMTilt'):
o6[:, [0, 2]] = _rotmat(e.BPMTilt) @ o6[:, [0, 2]]
o6[:, [0, 2]] = o6[:, [0, 2]] @ _rotmat(e.BPMTilt).T
if hasattr(e, 'BPMGain'):
o6[:, [0, 2]] *= e.BPMGain

Expand Down Expand Up @@ -282,7 +288,7 @@ def get_mean_std_err(ring: Lattice, key, attr, index=0):

def _sort_flags(kwargs):
errargs = {}
for key in _ERR_ATTRS + _SEL_ARGS:
for key in _SEL_ARGS:
if key in kwargs:
errargs[key] = kwargs.pop(key)
return kwargs, errargs
Expand Down
27 changes: 12 additions & 15 deletions pyat/at/lattice/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,17 @@ def _array66(value):
return _array(value, shape=(6, 6))


def _array2(value):
return _array(value, shape=(2,))


def _resize(value, shape=(3,), dtype=numpy.float64):
if not numpy.all(value.shape == shape):
value = value.copy()
value.resize(shape)
return _array(value, shape=shape, dtype=dtype)


def _broadcast(value, shape=(2,), dtype=numpy.float64):
v = numpy.broadcast_to(value, shape)
return _array(v, shape=shape, dtype=dtype)


def _nop(value):
return value

Expand Down Expand Up @@ -261,15 +260,15 @@ class Element(object):
T1=lambda v: _array(v, (6,)),
T2=lambda v: _array(v, (6,)),
RApertures=lambda v: _array(v, (4,)),
EApertures=lambda v: _array(v, (2,)),
KickAngle=lambda v: _array(v, (2,)),
EApertures=_array2,
KickAngle=_array2,
PolynomB=_array, PolynomA=_array,
PolynomBErr=_array, PolynomAErr=_array,
BendingAngle=float,
MaxOrder=int, NumIntSteps=int,
Energy=float,
BPMGain=_broadcast,
BPMOffset=_broadcast,
BPMGain=_array2,
BPMOffset=_array2,
BPMTilt=float,
ShiftErr=_not_allowed,
RotationErr=_not_allowed,
Expand Down Expand Up @@ -586,7 +585,7 @@ class ThinMultipole(Element):
_BUILD_ATTRIBUTES = Element._BUILD_ATTRIBUTES + ['PolynomA',
'PolynomB']
_conversions = dict(Element._conversions,
ShiftErr=_broadcast,
ShiftErr=_array2,
RotationErr=_resize)

def __init__(self, family_name: str, poly_a, poly_b, **kwargs):
Expand Down Expand Up @@ -645,11 +644,6 @@ def __setattr__(self, key, value):

super(ThinMultipole, self).__setattr__(key, value)

@property
def strength(self):
order = getattr(self, 'DefaultOrder', None)
return None if order is None else self.PolynomB[order]


class Multipole(_Radiative, LongElement, ThinMultipole):
"""Multipole element"""
Expand Down Expand Up @@ -679,6 +673,7 @@ def __init__(self, family_name: str, length: float, poly_a, poly_b,
super(Multipole, self).__init__(family_name, length,
poly_a, poly_b, **kwargs)

# noinspection PyUnresolvedReferences
def is_compatible(self, other) -> bool:
if super().is_compatible(other) and \
self.MaxOrder == other.MaxOrder:
Expand Down Expand Up @@ -793,6 +788,7 @@ def _part(self, fr, sumfr):
pp.ExitAngle = 0.0
return pp

# noinspection PyUnresolvedReferences,PyTypeChecker
def is_compatible(self, other) -> bool:
def invrho(dip: Dipole):
return dip.BendingAngle / dip.Length
Expand Down Expand Up @@ -940,6 +936,7 @@ def _part(self, fr, sumfr):
pp.Voltage = fr * self.Voltage
return pp

# noinspection PyUnresolvedReferences
def is_compatible(self, other) -> bool:
return (super().is_compatible(other) and
self.Frequency == other.Frequency and
Expand Down
52 changes: 52 additions & 0 deletions pyat/test/test_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import pytest
from at import checktype, Monitor
import numpy as np
from numpy.testing import assert_allclose as assert_close


@pytest.mark.parametrize('offset', ((0.001, 0.0), ([0.001, -0.001], 0.0)))
@pytest.mark.parametrize('gain', (([1.1, 0.9], 0.0), (1.05, 0.0)))
@pytest.mark.parametrize('tilt', ((-0.002, 0.0),))
def test_systematic_bpm_errors(hmba_lattice, offset, gain, tilt):
ring = hmba_lattice.copy()
bpms = ring.get_cells(checktype(Monitor))
ring.assign_errors(bpms, BPMOffset=offset, BPMGain=gain, BPMTilt=tilt)
bpmoff = np.vstack([el.BPMOffset for el in ring.select(bpms)])
bpmgain = np.vstack([el.BPMGain for el in ring.select(bpms)])
bpmtilt = np.array([el.BPMTilt for el in ring.select(bpms)])

# Check that all values are correctly assigned
assert np.all(bpmoff == offset[0])
assert np.all(bpmgain == gain[0])
assert np.all(bpmtilt == tilt[0])


def test_random_bpm_errors(hmba_lattice):

def _rotmat(theta):
cs = np.cos(theta)
sn = np.sin(theta)
return np.array([[cs, sn], [-sn, cs]])

ring = hmba_lattice.copy()
bpms = ring.get_cells(checktype(Monitor))

offset = (0.001, [0.001, 0.002])
gain = ([1.1, 0.9], 0.01)
tilt = 0.001

ring.assign_errors(bpms, BPMOffset=offset, BPMGain=gain, BPMTilt=tilt)

xyorbit0 = ring.find_orbit(bpms)[1][:, [0, 2]]
xyorbit = ring.find_orbit_err(bpms, all=False)[1][:, [0, 2]]

bpmoff = np.vstack([el.BPMOffset for el in ring.select(bpms)])
bpmgain = np.vstack([el.BPMGain for el in ring.select(bpms)])
bpmtilt = np.array([el.BPMTilt for el in ring.select(bpms)])

# Build all rotation matrices
rotm = np.stack([_rotmat(v).T for v in bpmtilt], axis=0)
# reshape offset,
of = np.reshape(xyorbit0+bpmoff, (-1, 1, 2))
expected = np.squeeze(of @ rotm) * bpmgain
assert_close(xyorbit, expected, rtol=0, atol=0)

0 comments on commit 629c0ff

Please sign in to comment.