Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sprint16 kl projection #167

Merged
merged 11 commits into from
Jan 25, 2023
7 changes: 7 additions & 0 deletions cuqi/_messages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
""" This module defines messages (errors, warnings, etc) that are used throughout CUQIpy. The values of the following variables are not meant to be changed by the user."""

_disable_warning_msg = lambda module_name: "To disable "+\
"warnings for a given module or library, "+\
"you can use the method `warnings.filterwarnings`,"+\
" e.g.: warnings.filterwarnings(action='ignore', "+\
f"module=r'"+module_name+"')."
24 changes: 23 additions & 1 deletion cuqi/geometry/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import operator
from functools import reduce
import warnings
from cuqi._messages import _disable_warning_msg

class Geometry(ABC):
"""A class that represents the geometry of the range, domain, observation, or other sets.
Expand Down Expand Up @@ -732,7 +733,28 @@ def par2fun(self, p):

def fun2par(self, funvals):
"""The function to parameter map used to map function values back to parameters, if available."""
jakobsj marked this conversation as resolved.
Show resolved Hide resolved
raise NotImplementedError("fun2par not implemented. ")
# Check that the input is of the correct shape
if len(funvals) != self.fun_dim:
raise ValueError(
"Input array funvals must have length {}".format(self.fun_dim))

warnings.warn(
f"fun2par for {self.__class__} is a projection on "
+ "the KL expansion coefficients space where only "
+ f"the first self.num_modes={self.num_modes} "
+ "coefficients are returned. "
+ _disable_warning_msg("cuqi.geometry")
+ "\n"
)


# Note, the scaling by 2*self.fun_dim is not needed if scipy.
jakobsj marked this conversation as resolved.
Show resolved Hide resolved
# fft.dst and scipy.fft.idst are used in fun2par and par2fun,
# instead of using scipy.fftpack.dst and scipy.fftpack.idst.
p = dst(funvals*2)[:self.par_dim]\
*self.normalizer/(self.coefs*2*self.fun_dim)

return p

class KLExpansion_Full(Continuous1D):
'''
Expand Down
27 changes: 27 additions & 0 deletions tests/test_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,3 +361,30 @@ def test_create_CustomKL_geometry():
geom.trunc_term==trunc_term


def test_KLExpansion_projection():
"""Check KLExpansion geometry projection performed by the method fun2par)"""
# Set up a KLExpansion geometry
num_modes = 95
N = 100
L = 1.0
grid = np.linspace(0,1,N)

geom = cuqi.geometry.KLExpansion(grid, num_modes=num_modes,
decay_rate=1.5,
normalizer=12.0)

# Create a signal
signal =1/30*(1-np.cos(2*np.pi*(L-grid)/(L)))\
+1/30*np.exp(-2*(10*(grid-0.5))**2)+\
1/30*np.exp(-2*(10*(grid-0.8))**2)

# Project signal to the KL basis and back
p = geom.fun2par(signal)
assert(len(p) == num_modes)

signal_proj = geom.par2fun(p)
assert(len(signal_proj) == N)

# Check that the projection is accurate
rel_err = np.linalg.norm(signal-signal_proj)/np.linalg.norm(signal)
assert np.isclose(rel_err, 0.0, atol=1e-5)
jakobsj marked this conversation as resolved.
Show resolved Hide resolved