Skip to content

Commit

Permalink
Merge pull request #394 from tomasstolker/wavelets
Browse files Browse the repository at this point in the history
Additional test cases for the wavelet denoising module
  • Loading branch information
Tomas Stolker committed Nov 4, 2019
2 parents f10d9fe + 594a3ed commit 952b6c6
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 33 deletions.
6 changes: 3 additions & 3 deletions pynpoint/processing/timedenoising.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __init__(self,

# check if wavelet is supported
if wavelet not in supported:
raise ValueError('DWT supports only ' + str(supported) + ' as input wavelet.')
raise ValueError(f'DWT supports only {supported} as input wavelet.')

self.m_wavelet = wavelet

Expand Down Expand Up @@ -177,7 +177,7 @@ def denoise_line_in_time(signal_in: np.ndarray) -> np.ndarray:
Parameters
----------
signal_in :
signal_in : numpy.ndarray
1D input signal.
Returns
Expand Down Expand Up @@ -219,7 +219,7 @@ def denoise_line_in_time(signal_in: np.ndarray) -> np.ndarray:
Parameters
----------
signal_in :
signal_in : numpy.ndarray
1D input signal.
Returns
Expand Down
130 changes: 100 additions & 30 deletions tests/test_processing/test_timedenoising.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from pynpoint.core.pypeline import Pypeline
from pynpoint.readwrite.fitsreading import FitsReadingModule
from pynpoint.processing.resizing import AddLinesModule
from pynpoint.processing.timedenoising import CwtWaveletConfiguration, DwtWaveletConfiguration, \
WaveletTimeDenoisingModule, TimeNormalizationModule
from pynpoint.util.tests import create_config, remove_test_data, create_star_data
Expand Down Expand Up @@ -40,13 +41,13 @@ def teardown_class(self):

def test_read_data(self):

read = FitsReadingModule(name_in='read',
image_tag='images',
input_dir=self.test_dir+'images',
overwrite=True,
check=True)
module = FitsReadingModule(name_in='read',
image_tag='images',
input_dir=self.test_dir+'images',
overwrite=True,
check=True)

self.pipeline.add_module(read)
self.pipeline.add_module(module)
self.pipeline.run_module('read')

data = self.pipeline.get_data('images')
Expand All @@ -66,15 +67,15 @@ def test_wavelet_denoising_cwt_dog(self):
assert not cwt_config.m_keep_mean
assert np.allclose(cwt_config.m_resolution, 0.5, rtol=limit, atol=0.)

wavelet_cwt = WaveletTimeDenoisingModule(wavelet_configuration=cwt_config,
name_in='wavelet_cwt_dog',
image_in_tag='images',
image_out_tag='wavelet_cwt_dog',
padding='zero',
median_filter=True,
threshold_function='soft')
module = WaveletTimeDenoisingModule(wavelet_configuration=cwt_config,
name_in='wavelet_cwt_dog',
image_in_tag='images',
image_out_tag='wavelet_cwt_dog',
padding='zero',
median_filter=True,
threshold_function='soft')

self.pipeline.add_module(wavelet_cwt)
self.pipeline.add_module(module)
self.pipeline.run_module('wavelet_cwt_dog')

data = self.pipeline.get_data('wavelet_cwt_dog')
Expand Down Expand Up @@ -106,37 +107,40 @@ def test_wavelet_denoising_cwt_morlet(self):
assert not cwt_config.m_keep_mean
assert np.allclose(cwt_config.m_resolution, 0.5, rtol=limit, atol=0.)

wavelet_cwt = WaveletTimeDenoisingModule(wavelet_configuration=cwt_config,
name_in='wavelet_cwt_morlet',
image_in_tag='images',
image_out_tag='wavelet_cwt_morlet',
padding='mirror',
median_filter=False,
threshold_function='hard')
module = WaveletTimeDenoisingModule(wavelet_configuration=cwt_config,
name_in='wavelet_cwt_morlet',
image_in_tag='images',
image_out_tag='wavelet_cwt_morlet',
padding='mirror',
median_filter=False,
threshold_function='hard')

self.pipeline.add_module(wavelet_cwt)
self.pipeline.add_module(module)
self.pipeline.run_module('wavelet_cwt_morlet')

data = self.pipeline.get_data('wavelet_cwt_morlet')
assert np.allclose(data[0, 10, 10], 0.09805577173716859, rtol=limit, atol=0.)
assert np.allclose(np.mean(data), 0.0025019409784314286, rtol=limit, atol=0.)
assert data.shape == (40, 20, 20)

data = self.pipeline.get_attribute('wavelet_cwt_morlet', 'NFRAMES', static=False)
assert np.allclose(data, [10, 10, 10, 10], rtol=limit, atol=0.)

def test_wavelet_denoising_dwt(self):

dwt_config = DwtWaveletConfiguration(wavelet='db8')

assert dwt_config.m_wavelet == 'db8'

wavelet_dwt = WaveletTimeDenoisingModule(wavelet_configuration=dwt_config,
name_in='wavelet_dwt',
image_in_tag='images',
image_out_tag='wavelet_dwt',
padding='zero',
median_filter=True,
threshold_function='soft')
module = WaveletTimeDenoisingModule(wavelet_configuration=dwt_config,
name_in='wavelet_dwt',
image_in_tag='images',
image_out_tag='wavelet_dwt',
padding='zero',
median_filter=True,
threshold_function='soft')

self.pipeline.add_module(wavelet_dwt)
self.pipeline.add_module(module)
self.pipeline.run_module('wavelet_dwt')

data = self.pipeline.get_data('wavelet_dwt')
Expand All @@ -157,3 +161,69 @@ def test_time_normalization(self):
assert np.allclose(data[0, 10, 10], 0.09793500165714215, rtol=limit, atol=0.)
assert np.allclose(np.mean(data), 0.0024483409033199985, rtol=limit, atol=0.)
assert data.shape == (40, 20, 20)

def test_wavelet_denoising_odd_size(self):

module = AddLinesModule(name_in='add',
image_in_tag='images',
image_out_tag='images_odd',
lines=(1, 0, 1, 0))

self.pipeline.add_module(module)
self.pipeline.run_module('add')

data = self.pipeline.get_data('images_odd')
assert np.allclose(data[0, 10, 10], 0.05294085050174391, rtol=limit, atol=0.)
assert np.allclose(np.mean(data), 0.002269413609192613, rtol=limit, atol=0.)
assert data.shape == (40, 21, 21)

cwt_config = CwtWaveletConfiguration(wavelet='dog',
wavelet_order=2,
keep_mean=False,
resolution=0.5)

assert cwt_config.m_wavelet == 'dog'
assert np.allclose(cwt_config.m_wavelet_order, 2, rtol=limit, atol=0.)
assert not cwt_config.m_keep_mean
assert np.allclose(cwt_config.m_resolution, 0.5, rtol=limit, atol=0.)

module = WaveletTimeDenoisingModule(wavelet_configuration=cwt_config,
name_in='wavelet_odd_1',
image_in_tag='images_odd',
image_out_tag='wavelet_odd_1',
padding='zero',
median_filter=True,
threshold_function='soft')

self.pipeline.add_module(module)
self.pipeline.run_module('wavelet_odd_1')

data = self.pipeline.get_data('wavelet_odd_1')
assert np.allclose(data[0, 10, 10], 0.0529782051386938, rtol=limit, atol=0.)
assert np.allclose(np.mean(data), 0.0022694631406801565, rtol=limit, atol=0.)
assert data.shape == (40, 21, 21)

module = WaveletTimeDenoisingModule(wavelet_configuration=cwt_config,
name_in='wavelet_odd_2',
image_in_tag='images_odd',
image_out_tag='wavelet_odd_2',
padding='mirror',
median_filter=True,
threshold_function='soft')

self.pipeline.add_module(module)
self.pipeline.run_module('wavelet_odd_2')

data = self.pipeline.get_data('wavelet_odd_2')
assert np.allclose(data[0, 10, 10], 0.05297146283932275, rtol=limit, atol=0.)
assert np.allclose(np.mean(data), 0.0022694809842930034, rtol=limit, atol=0.)
assert data.shape == (40, 21, 21)

data = self.pipeline.get_attribute('images', 'NFRAMES', static=False)
assert np.allclose(data, [10, 10, 10, 10], rtol=limit, atol=0.)

data = self.pipeline.get_attribute('wavelet_odd_1', 'NFRAMES', static=False)
assert np.allclose(data, [10, 10, 10, 10], rtol=limit, atol=0.)

data = self.pipeline.get_attribute('wavelet_odd_2', 'NFRAMES', static=False)
assert np.allclose(data, [10, 10, 10, 10], rtol=limit, atol=0.)

0 comments on commit 952b6c6

Please sign in to comment.