Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 15 additions & 14 deletions pipt/misc_tools/wavelet_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ class SparseRepresentation:
def __init__(self, options):
# options: dim, actnum, level, wname, colored_noise, threshold_rule, th_mult,
# use_hard_th, keep_ca, inactive_value
# dim must be given as (nz,ny,nx)
self.options = options
self.num_grid = np.prod(self.options['dim'])

Expand All @@ -30,7 +29,7 @@ def __init__(self, options):
self.ca_leading_index = None
self.ca_leading_coeff = None

# Function to doing image compression. If the function is called without threshold, then the leading indices must
# Function for image compression. If the function is called without threshold, then the leading indices must
# be defined in the class. Typically, this is done by running the compression on true data with a given threshold.
def compress(self, data, th_mult=None):
if ('inactive_value' not in self.options) or (self.options['inactive_value'] is None):
Expand All @@ -43,13 +42,10 @@ def compress(self, data, th_mult=None):
if 'min_noise' not in self.options:
self.options['min_noise'] = 1.0e-9
signal = signal.reshape(self.options['dim'], order=self.options['order'])
# get the signal back into its original shape (nx,ny,nz)
signal = signal.transpose((2, 1, 0))
# pywt throws a warning in case of single-dimentional entries in the shape of the signal.
with warnings.catch_warnings():
warnings.simplefilter("ignore")
wdec = pywt.wavedecn(
signal, self.options['wname'], 'symmetric', int(self.options['level']))

# Wavelet decomposition
wdec = pywt.wavedecn(signal, self.options['wname'], 'symmetric', int(self.options['level']))

wdec_rec = deepcopy(wdec)

# Perform thresholding if the threshold is given as input.
Expand All @@ -65,7 +61,12 @@ def compress(self, data, th_mult=None):

# Initialize
# Note: the keys below are organized the same way as in Matlab.
keys = ['daa', 'ada', 'dda', 'aad', 'dad', 'add', 'ddd']
if signal.ndim == 3:
keys = ['daa', 'ada', 'dda', 'aad', 'dad', 'add', 'ddd']
details = 'ddd'
elif signal.ndim == 2:
keys = ['da', 'ad', 'dd']
details = 'dd'
if true_data:
for level in range(0, int(self.options['level'])+1):
num_subband = 1 if level == 0 else len(keys)
Expand All @@ -85,7 +86,7 @@ def compress(self, data, th_mult=None):

# In the white noise case estimated std is based on the high (hhh) subband only
if true_data and not self.options['colored_noise']:
subband_hhh = wdec[-1]['ddd'].flatten()
subband_hhh = wdec[-1][details].flatten()
est_noise_level = np.median(
np.abs(subband_hhh - np.median(subband_hhh))) / 0.6745 # estimated noise std
est_noise_level = np.maximum(est_noise_level, self.options['min_noise'])
Expand Down Expand Up @@ -224,10 +225,10 @@ def reconstruct(self, wdec_rec):
print('No signal to reconstruct')
sys.exit(1)

# reconstruct from wavelet coefficients
data_rec = pywt.waverecn(wdec_rec, self.options['wname'], 'symmetric')
data_rec = data_rec.transpose((2, 1, 0)) # flip the axes
dim = self.options['dim']
data_rec = data_rec[0:dim[0], 0:dim[1], 0:dim[2]] # severe issure here
data_rec = data_rec[tuple(slice(0, s) for s in self.options['dim'])]

data_rec = data_rec.flatten(order=self.options['order'])
data_rec = data_rec[self.options['mask']]

Expand Down
Loading