-
Notifications
You must be signed in to change notification settings - Fork 0
/
_base.py
104 lines (84 loc) · 3.32 KB
/
_base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# lsqfitgp/_GP/_base.py
#
# Copyright (c) 2020, 2022, 2023, Giacomo Petrillo
#
# This file is part of lsqfitgp.
#
# lsqfitgp is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# lsqfitgp is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
import functools
import jax
from jax import numpy as jnp
from .. import _jaxext
from .. import _utils
class GPBase:
    """
    Base class storing the check flags and common helpers of a GP object.

    Parameters
    ----------
    checkfinite : bool, default True
        Flag stored (coerced to `bool`) in the ``_checkfinite`` attribute.
    checklin : bool, default True
        Flag stored (coerced to `bool`) in the ``_checklin`` attribute.
    """

    def __init__(self, *, checkfinite=True, checklin=True):
        self._checkfinite = bool(checkfinite)
        self._checklin = bool(checklin)

    def _clone(self):
        """
        Return a new object of the same class with the same check flags.

        The instance is allocated with ``object.__new__``, so ``__init__`` is
        not invoked on it; only the attributes set by this class are copied.
        """
        newself = object.__new__(self.__class__)
        newself._checkfinite = self._checkfinite
        newself._checklin = self._checklin
        return newself

    class _SingletonMeta(type):
        """ Metaclass making the repr of a class its bare name. """

        def __repr__(cls):
            return cls.__name__

    class _Singleton(metaclass=_SingletonMeta):
        """ Base for classes used as markers and never instantiated. """

        def __new__(cls):
            raise NotImplementedError(f"{cls.__name__} can not be instantiated")

    class DefaultProcess(_Singleton):
        """ Key of the default process. """
        pass

    def _checklinear(self, func, inshapes, elementwise=False):
        """
        Check numerically that a function is linear, optionally elementwise.

        The function is evaluated with `jax.jvp` on random normal inputs
        generated from a fixed seed, using the inputs themselves as tangents.
        The value and the derivative along the input are then required to be
        close, as they are for a linear function.

        Parameters
        ----------
        func : callable
            Function taking ``len(inshapes)`` arrays and returning an array.
        inshapes : sequence of tuple of int
            The shapes of the test inputs, one per argument of `func`.
        elementwise : bool, default False
            If True, random entries of the inputs are set to zero, and the
            output is additionally required to have the broadcast shape of
            `inshapes` and to be close to zero at those same entries.

        Raises
        ------
        RuntimeError
            If the linearity check or the elementwise check fails.
        """

        # Make input arrays.
        rkey = jax.random.PRNGKey(202206091600)
        inp = []
        for shape in inshapes:
            rkey, subkey = jax.random.split(rkey)
            inp.append(jax.random.normal(subkey, shape))

        # Put zeros into the arrays to check they are preserved.
        if elementwise:
            shape = jnp.broadcast_shapes(*inshapes)
            rkey, subkey = jax.random.split(rkey)
            zeros = jax.random.bernoulli(subkey, 0.5, shape)
            # NOTE(review): the mask has the broadcast shape, so indexing
            # with it presumes each input already has that shape -- confirm.
            inp = [a.at[zeros].set(0) for a in inp]

        # Compute JVP and check it is identical to the function itself.
        # NOTE(review): skipifabstract is defined in _jaxext; presumably it
        # skips these checks when the values are abstract tracers -- confirm.
        with _jaxext.skipifabstract():
            out0, out1 = jax.jvp(func, inp, inp)
            if out1.dtype == jax.float0:
                # float0 tangent: there is no derivative to compare with, so
                # the output itself is required to vanish.
                cond = jnp.allclose(out0, 0)
            else:
                cond = jnp.allclose(out0, out1)
            if not cond:
                raise RuntimeError('the transformation is not linear')

            # Check that the function is elementwise. This is kept inside
            # the guarded block because it reads out0 and out1, which are
            # bound only if the statements above were executed.
            if elementwise and (
                out0.shape != shape
                or not (
                    jnp.allclose(out0[zeros], 0)
                    and jnp.allclose(out1[zeros], 0)
                )
            ):
                raise RuntimeError('the transformation is not elementwise')
def newself(meth):
    """
    Decorator to create a new GP object and pass it to the method.

    The decorated method is called on a clone of the instance, obtained with
    ``self._clone()``, instead of on the instance itself. Whatever the method
    returns is discarded and the clone is returned. A "Returns" section
    describing this is appended to the docstring of the method.

    Parameters
    ----------
    meth : callable
        The method to wrap. It receives the clone as first argument, followed
        by the arguments of the call; it is expected to modify the clone.

    Returns
    -------
    newmeth : callable
        The wrapper, carrying the metadata of `meth` and the extended
        docstring.
    """

    @functools.wraps(meth)
    def newmeth(self, *args, **kw):
        clone = self._clone()
        meth(clone, *args, **kw)
        return clone

    # append return value description to docstring
    doctail = """\
Returns
-------
gp : GP
    A new GP object with the applied modifications.
"""
    newmeth.__doc__ = _utils.append_to_docstring(meth.__doc__, doctail)

    return newmeth