Skip to content

Commit

Permalink
kernel: make attributes public
Browse files Browse the repository at this point in the history
  • Loading branch information
Gattocrucco committed Sep 1, 2023
1 parent d799260 commit 46286b5
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 62 deletions.
28 changes: 14 additions & 14 deletions src/lsqfitgp/_Kernel/_alg.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,15 @@ def add(tcls, self, other):
The other kernel.
"""
core = self._core
core = self.core
if _util.is_numerical_scalar(other):
newcore = lambda x, y, **kw: core(x, y, **kw) + other
elif isinstance(other, CrossKernel):
other = other._core
other = other.core
newcore = lambda x, y, **kw: core(x, y, **kw) + other(x, y, **kw)
else:
return NotImplemented
return self._clone(_core=newcore)
return self._clone(core=newcore)

@CrossKernel.register_algop
def mul(tcls, self, other):
Expand All @@ -71,15 +71,15 @@ def mul(tcls, self, other):
The other kernel.
"""
core = self._core
core = self.core
if _util.is_numerical_scalar(other):
newcore = lambda x, y, **kw: core(x, y, **kw) * other
elif isinstance(other, CrossKernel):
other = other._core
other = other.core
newcore = lambda x, y, **kw: core(x, y, **kw) * other(x, y, **kw)
else:
return NotImplemented
return self._clone(_core=newcore)
return self._clone(core=newcore)

@CrossKernel.register_algop
def pow(tcls, self, *, exponent):
Expand All @@ -97,9 +97,9 @@ def pow(tcls, self, *, exponent):
"""
if _util.is_nonnegative_integer_scalar(exponent):
core = self._core
core = self.core
newcore = lambda x, y, **kw: core(x, y, **kw) ** exponent
return self._clone(_core=newcore)
return self._clone(core=newcore)
else:
return NotImplemented

Expand All @@ -123,9 +123,9 @@ def rpow(tcls, self, *, base):
"""
if _util.is_scalar_cond_trueontracer(base, lambda x: x >= 1):
core = self._core
core = self.core
newcore = lambda x, y, **kw: base ** core(x, y, **kw)
return self._clone(_core=newcore)
return self._clone(core=newcore)
else:
return NotImplemented

Expand Down Expand Up @@ -156,19 +156,19 @@ def rpow(tcls, self, *, base):
def affine_add(tcls, self, other):
newself = AffineSpan.super_transf('add', self, other)
if _util.is_numerical_scalar(other):
dynkw = dict(self._dynkw)
dynkw = dict(self.dynkw)
dynkw['offset'] = dynkw['offset'] + other
return newself._clone(self.__class__, _dynkw=dynkw)
return newself._clone(self.__class__, dynkw=dynkw)
else:
return newself

@functools.partial(AffineSpan.register_algop, transfname='mul')
def affine_mul(tcls, self, other):
newself = AffineSpan.super_transf('mul', self, other)
if _util.is_numerical_scalar(other):
dynkw = dict(self._dynkw)
dynkw = dict(self.dynkw)
dynkw['offset'] = other * dynkw['offset']
dynkw['ampl'] = other * dynkw['ampl']
return newself._clone(self.__class__, _dynkw=dynkw)
return newself._clone(self.__class__, dynkw=dynkw)
else:
return newself
81 changes: 52 additions & 29 deletions src/lsqfitgp/_Kernel/_crosskernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,21 @@ class CrossKernel:
If specified, apply ``.batch(batchbytes)`` to the kernel.
dynkw : dict, optional
Additional keyword arguments passed to `core` that can be modified
by transformations.
**kw :
by transformations. Deleted by transformations by default.
**initkw :
Additional keyword arguments passed to `core` that can be read but not
changed by transformations.
Attributes
----------
initkw : dict
The `initkw` argument.
dynkw : dict
The `dynkw` argument, or a modification of it if the object has been
transformed.
core : callable
The `core` argument partially evaluated on `initkw`, or another
function wrapping it if the object has been transformed.
Methods
-------
Expand All @@ -108,6 +119,7 @@ class CrossKernel:
has_transf
list_transf
super_transf
make_linop_family
See also
--------
Expand All @@ -121,7 +133,20 @@ class CrossKernel:
"""

__slots__ = '_kw', '_dynkw', '_core'
__slots__ = '_initkw', '_dynkw', '_core'
# only __new__ and _clone shall access these attributes

@property
def initkw(self):
return types.MappingProxyType(self._initkw)

@property
def dynkw(self):
return types.MappingProxyType(self._dynkw)

@property
def core(self):
return self._core

def __new__(cls, core, *,
scale=None,
Expand All @@ -132,13 +157,13 @@ def __new__(cls, core, *,
forcekron=False,
batchbytes=None,
dynkw={},
**kw,
**initkw,
):
self = super().__new__(cls)

self._kw = kw
self._dynkw = dynkw
self._core = lambda x, y, **dynkw: core(x, y, **kw, **dynkw)
self._initkw = initkw
self._dynkw = dict(dynkw)
self._core = lambda x, y, **dynkw: core(x, y, **initkw, **dynkw)

if forcekron:
self = self.transf('forcekron')
Expand All @@ -152,7 +177,7 @@ def __new__(cls, core, *,
}
for transfname, arg in linop_args.items():
if callable(arg):
arg = arg(**kw)
arg = arg(**initkw)
if isinstance(arg, tuple):
self = self.linop(transfname, *arg)
else:
Expand All @@ -167,19 +192,17 @@ def __call__(self, x, y):
x = _array.asarray(x)
y = _array.asarray(y)
shape = _array.broadcast(x, y).shape
result = self._core(x, y, **self._dynkw)
result = self.core(x, y, **self.dynkw)
assert isinstance(result, (numpy.ndarray, jnp.number, jnp.ndarray))
assert jnp.issubdtype(result.dtype, jnp.number), result.dtype
assert result.shape == shape, (result.shape, shape)
return result

def _clone(self, cls=None, **attrs):
def _clone(self, cls=None, *, initkw=None, dynkw=None, core=None):
newself = object.__new__(self.__class__ if cls is None else cls)
newself._kw = self._kw
newself._dynkw = {}
newself._core = self._core
for k, v in attrs.items():
setattr(newself, k, v)
newself._initkw = self._initkw if initkw is None else dict(initkw)
newself._dynkw = {} if dynkw is None else dict(dynkw)
newself._core = self._core if core is None else core
return newself

class _side(enum.Enum):
Expand All @@ -198,7 +221,7 @@ def _nary(cls, op, kernels, side):
else: # pragma: no cover
raise KeyError(side)

cores = [k._core for k in kernels]
cores = [k.core for k in kernels]
def core(x, y, **kw):
wrapped = [wrapper(c, x, y, **kw) for c in cores]
transformed = op(*wrapped)
Expand Down Expand Up @@ -233,10 +256,10 @@ def __rpow__(self, other):

def _swap(self):
""" permute the arguments """
core = self._core
core = self.core
return self._clone(
__class__,
_core=lambda x, y, **kw: core(y, x, **kw),
core=lambda x, y, **kw: core(y, x, **kw),
)

# TODO make _swap a transf inherited by CrossIsotropicKernel => messes
Expand All @@ -263,8 +286,8 @@ def batch(self, maxnbytes):
batched_kernel : CrossKernel
The same kernel but with batched computations.
"""
core = _jaxext.batchufunc(self._core, maxnbytes=maxnbytes)
return self._clone(_core=core)
core = _jaxext.batchufunc(self.core, maxnbytes=maxnbytes)
return self._clone(core=core)

@classmethod
def _crossmro(cls):
Expand Down Expand Up @@ -842,9 +865,9 @@ def register_corelinop(cls, corefunc, transfname=None, doc=None, argparser=None)
"""
@functools.wraps(corefunc)
def op(_, self, arg1, arg2, *operands):
cores = (o._core for o in operands)
core = corefunc(self._core, arg1, arg2, *cores)
return self._clone(_core=core)
cores = (o.core for o in operands)
core = corefunc(self.core, arg1, arg2, *cores)
return self._clone(core=core)
cls.register_linop(op, transfname, doc, argparser)
return corefunc

Expand Down Expand Up @@ -950,7 +973,7 @@ def classes():
cls.register_transf(func, transfname, doc, cls._algopmarker)
return op

# TODO delete _kw (also in linop) if there's more than one kernel
# TODO delete initkw (also in linop) if there's more than one kernel
# operand or if the class changed?

# TODO consider adding an option domains=callable, returns list of
Expand Down Expand Up @@ -995,14 +1018,14 @@ def register_ufuncalgop(cls, ufunc, transfname=None, doc=None):
@functools.wraps(ufunc)
def op(_, self, *operands, **kw):
cores = tuple(
o._core if isinstance(o, __class__)
o.core if isinstance(o, __class__)
else lambda x, y: o
for o in (self, *operands)
)
def core(x, y, **kw):
values = (core(x, y, **kw) for core in cores)
return ufunc(*values, **kw)
return self._clone(_core=core)
return self._clone(core=core)
cls.register_algop(op, transfname, doc)
return ufunc

Expand Down Expand Up @@ -1102,7 +1125,7 @@ def __new__(cls, *args, **kw):

# function to produce the arguments to the transformed objects
def makekw(self, arg1, arg2):
kw = dict(dynkw=self._dynkw, **self._kw)
kw = dict(dynkw=self.dynkw, **self.initkw)
if argnames is not None:
if arg1 is not None:
kw = dict(**kw, **{argnames[0]: arg1})
Expand Down Expand Up @@ -1240,8 +1263,8 @@ def __subclasshook__(cls, sub):

# TODO I could do separately AffineLeft, AffineRight, AffineOut, and make
# this a subclass of those three. AffineOut would also allow keeping the
# class when adding two objects without keyword arguments in _kw and _dynkw
# beyond those managed by Affine.
# class when adding two objects without keyword arguments in initkw and
# dynkw beyond those managed by Affine.

# TODO when I reimplement transformations as methods, make AffineSpan not
# a subclass of CrossKernel. Right now I have to to avoid routing around
Expand Down
2 changes: 1 addition & 1 deletion src/lsqfitgp/_Kernel/_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __new__(cls, **kw):
warnings.warn(f'overriding init argument(s) '
f'{shared_keys} of kernel {name}')
self = super(newclass, cls).__new__(cls, core, **kwargs)
if isinstance(self, base) and set(kw).issubset(self._kw):
if isinstance(self, base) and set(kw).issubset(self.initkw):
self = self._clone(cls)
return self

Expand Down
4 changes: 2 additions & 2 deletions src/lsqfitgp/_Kernel/_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ def forcekron(tcls, self):
"""

core = self._core
core = self.core
newcore = lambda x, y, **kw: _util.prod_recurse_dtype(core, x, y, **kw)
return self._clone(tcls, _core=newcore)
return self._clone(tcls, core=newcore)

_crosskernel.Kernel = Kernel
8 changes: 4 additions & 4 deletions src/lsqfitgp/_Kernel/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,17 +382,17 @@ def newcore(x, y, **kw):

@functools.partial(AffineSpan.register_linop, transfname='loc')
def affine_loc(tcls, self, xloc, yloc):
dynkw = dict(self._dynkw)
dynkw = dict(self.dynkw)
newself = tcls.super_transf('loc', self, xloc, yloc)
ploc = dynkw['loc']
pscale = dynkw['scale']
dynkw['loc'] = (ploc[0] + xloc * pscale[0], ploc[1] + yloc * pscale[1])
return newself._clone(self.__class__, _dynkw=dynkw)
return newself._clone(self.__class__, dynkw=dynkw)

@functools.partial(AffineSpan.register_linop, transfname='scale')
def affine_scale(tcls, self, xscale, yscale):
dynkw = dict(self._dynkw)
dynkw = dict(self.dynkw)
newself = tcls.super_transf('scale', self, xscale, yscale)
pscale = dynkw['scale']
dynkw['scale'] = (pscale[0] * xscale, pscale[1] * yscale)
return newself._clone(self.__class__, _dynkw=dynkw)
return newself._clone(self.__class__, dynkw=dynkw)
4 changes: 2 additions & 2 deletions src/lsqfitgp/_kernels/_zeta.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def CrossZetaFourier(k, y, *, nu):
return jnp.where(k > 0, jnp.where(odd, jnp.sin(arg), jnp.cos(arg)) / denom, 0)

@functools.partial(Zeta.register_linop, argparser=lambda do: do if do else None)
def fourier(self, dox, doy):
def fourier(_, self, dox, doy):
r"""
Compute the Fourier series transform of the function.
Expand All @@ -118,7 +118,7 @@ def fourier(self, dox, doy):
"""

nu = self._kw['nu']
nu = self.initkw['nu']

if dox and doy:
return ZetaFourier(nu=nu)
Expand Down
20 changes: 10 additions & 10 deletions tests/kernels/test_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# You should have received a copy of the GNU General Public License
# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.

""" Test the generic kernel machinery. This file covers at 100% the _Kernel
""" Test the generic kernel machinery. This file shall cover at 100% the _Kernel
submodule. """

import sys
Expand Down Expand Up @@ -403,7 +403,7 @@ class A(lgp.CrossKernel): pass
a = A(constcore)
b = a.linop('ciao', 1, 2)
assert a is b
assert a._core is b._core
assert a.core is b.core

def test_class_goes_to_cross_parent(self, constcore, idtransf):
class A(lgp.CrossKernel): pass
Expand Down Expand Up @@ -935,16 +935,16 @@ class A(lgp._Kernel.AffineSpan, lgp.CrossKernel): pass
a = a.linop('loc', 17, 19)

# compare accumulated coefficients with manual calculation
assert a._dynkw['offset'] == (2 * 3 + 5) * 7
assert a._dynkw['ampl'] == 3 * 7
assert a._dynkw['loc'] == (2 * (5 + 11 * 17), 3 * (7 + 13 * 19))
assert a._dynkw['scale'] == (2 * 11, 3 * 13)
assert a.dynkw['offset'] == (2 * 3 + 5) * 7
assert a.dynkw['ampl'] == 3 * 7
assert a.dynkw['loc'] == (2 * (5 + 11 * 17), 3 * (7 + 13 * 19))
assert a.dynkw['scale'] == (2 * 11, 3 * 13)

# compare result with specification of coefficients
c1 = a(x, y)
c2 = a._dynkw['offset'] + a._dynkw['ampl'] * a0(
(x - a._dynkw['loc'][0]) / a._dynkw['scale'][0],
(y - a._dynkw['loc'][1]) / a._dynkw['scale'][1],
c2 = a.dynkw['offset'] + a.dynkw['ampl'] * a0(
(x - a.dynkw['loc'][0]) / a.dynkw['scale'][0],
(y - a.dynkw['loc'][1]) / a.dynkw['scale'][1],
)
util.assert_allclose(c1, c2)

Expand All @@ -957,7 +957,7 @@ def test_callable_arg(constcore, rng):
def test_init_kw_preserved(constcore):
kernel = lgp.Kernel(constcore, cippa=4)
def check(k):
assert k._kw['cippa'] == 4
assert k.initkw['cippa'] == 4
check(kernel._swap())
check(kernel.linop('loc', 1, 2))
check(kernel.transf('forcekron'))
Expand Down

0 comments on commit 46286b5

Please sign in to comment.