Skip to content

Commit

Permalink
Merge pull request #167 from CUQI-DTU/sprint16_KL_projection
Browse files Browse the repository at this point in the history
Sprint16 kl projection
  • Loading branch information
amal-ghamdi committed Jan 25, 2023
2 parents cd282e9 + cc24278 commit 16e3924
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 2 deletions.
7 changes: 7 additions & 0 deletions cuqi/_messages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
""" This module defines messages (errors, warnings, etc) that are used throughout CUQIpy. The values of the following variables are not meant to be changed by the user."""

_disable_warning_msg = lambda module_name: "To disable "+\
"warnings for a given module or library, "+\
"you can use the method `warnings.filterwarnings`,"+\
" e.g.: warnings.filterwarnings(action='ignore', "+\
f"module=r'"+module_name+"')."
41 changes: 39 additions & 2 deletions cuqi/geometry/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import operator
from functools import reduce
import warnings
from cuqi._messages import _disable_warning_msg

class Geometry(ABC):
"""A class that represents the geometry of the range, domain, observation, or other sets.
Expand Down Expand Up @@ -731,8 +732,44 @@ def par2fun(self, p):
return real

def fun2par(self, funvals):
"""The function to parameter map used to map function values back to parameters, if available."""
raise NotImplementedError("fun2par not implemented. ")
"""The function to parameter map used to map function values back to
parameters. In this class (the `KLExpansion`), `fun2par` projects the
function on the KL expansion coefficient space. Hence this is not
always the inverse of `par2fun` but it is the closest estimation of the
function on the KL expansion coefficient space."""

# Check that the input is of the correct shape
if len(funvals) != self.fun_dim:
raise ValueError(
"Input array funvals must have length {}".format(self.fun_dim))

warnings.warn(
f"fun2par for {self.__class__} is a projection on "
+ "the KL expansion coefficients space where only "
+ f"the first self.num_modes={self.num_modes} "
+ "coefficients are returned. "
+ _disable_warning_msg("cuqi.geometry")
+ "\n"
)

# par2fun scales the input parameters then applies the inverse
# discrete sine transform of type 2 (IDST-II). Here we apply the
# inverse of the par2fun map and truncate the expansion
# coefficients to the number of modes.
# This includes applying the discrete sine transform of type 2 (DST-II)
# to the function values to obtain the expansion coefficients.
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.fftpack.dst.html

# Note that here we have a scaling by 2*self.fun_dim that does
# not correspond to the scaling in par2fun. This is needed
# because it is not accounted for in the scipy.fftpack implementation.
# However, if we use, for example, scipy.fft instead of scipy.fftpack,
# then this scaling is not needed.

p = dst(funvals*2)[:self.par_dim]\
*self.normalizer/(self.coefs*2*self.fun_dim)

return p

class KLExpansion_Full(Continuous1D):
'''
Expand Down
27 changes: 27 additions & 0 deletions tests/test_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,3 +361,30 @@ def test_create_CustomKL_geometry():
geom.trunc_term==trunc_term


def test_KLExpansion_projection():
"""Check KLExpansion geometry projection performed by the method fun2par)"""
# Set up a KLExpansion geometry
num_modes = 95
N = 100
L = 1.0
grid = np.linspace(0,1,N)

geom = cuqi.geometry.KLExpansion(grid, num_modes=num_modes,
decay_rate=1.5,
normalizer=12.0)

# Create a signal
signal =1/30*(1-np.cos(2*np.pi*(L-grid)/(L)))\
+1/30*np.exp(-2*(10*(grid-0.5))**2)+\
1/30*np.exp(-2*(10*(grid-0.8))**2)

# Project signal to the KL basis and back
p = geom.fun2par(signal)
assert(len(p) == num_modes)

signal_proj = geom.par2fun(p)
assert(len(signal_proj) == N)

# Check that the projection is accurate
rel_err = np.linalg.norm(signal-signal_proj)/np.linalg.norm(signal)
assert np.isclose(rel_err, 0.0, atol=1e-5)

0 comments on commit 16e3924

Please sign in to comment.