diff --git a/pipt/misc_tools/wavelet_tools.py b/pipt/misc_tools/wavelet_tools.py index a907cb6..ac206de 100644 --- a/pipt/misc_tools/wavelet_tools.py +++ b/pipt/misc_tools/wavelet_tools.py @@ -16,7 +16,6 @@ class SparseRepresentation: def __init__(self, options): # options: dim, actnum, level, wname, colored_noise, threshold_rule, th_mult, # use_hard_th, keep_ca, inactive_value - # dim must be given as (nz,ny,nx) self.options = options self.num_grid = np.prod(self.options['dim']) @@ -30,7 +29,7 @@ def __init__(self, options): self.ca_leading_index = None self.ca_leading_coeff = None - # Function to doing image compression. If the function is called without threshold, then the leading indices must + # Function for image compression. If the function is called without threshold, then the leading indices must # be defined in the class. Typically, this is done by running the compression on true data with a given threshold. def compress(self, data, th_mult=None): if ('inactive_value' not in self.options) or (self.options['inactive_value'] is None): @@ -43,13 +42,10 @@ def compress(self, data, th_mult=None): if 'min_noise' not in self.options: self.options['min_noise'] = 1.0e-9 signal = signal.reshape(self.options['dim'], order=self.options['order']) - # get the signal back into its original shape (nx,ny,nz) - signal = signal.transpose((2, 1, 0)) - # pywt throws a warning in case of single-dimentional entries in the shape of the signal. - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - wdec = pywt.wavedecn( - signal, self.options['wname'], 'symmetric', int(self.options['level'])) + + # Wavelet decomposition + wdec = pywt.wavedecn(signal, self.options['wname'], 'symmetric', int(self.options['level'])) + wdec_rec = deepcopy(wdec) # Perform thresholding if the threshold is given as input. @@ -65,7 +61,12 @@ def compress(self, data, th_mult=None): # Initialize # Note: the keys below are organized the same way as in Matlab. - keys = ['daa', 'ada', 'dda', 'aad', 'dad', 'add', 'ddd'] + if signal.ndim == 3: + keys = ['daa', 'ada', 'dda', 'aad', 'dad', 'add', 'ddd'] + details = 'ddd' + elif signal.ndim == 2: + keys = ['da', 'ad', 'dd'] + details = 'dd' if true_data: for level in range(0, int(self.options['level'])+1): num_subband = 1 if level == 0 else len(keys) @@ -85,7 +86,7 @@ def compress(self, data, th_mult=None): # In the white noise case estimated std is based on the high (hhh) subband only if true_data and not self.options['colored_noise']: - subband_hhh = wdec[-1]['ddd'].flatten() + subband_hhh = wdec[-1][details].flatten() est_noise_level = np.median( np.abs(subband_hhh - np.median(subband_hhh))) / 0.6745 # estimated noise std est_noise_level = np.maximum(est_noise_level, self.options['min_noise']) @@ -224,10 +225,10 @@ def reconstruct(self, wdec_rec): print('No signal to reconstruct') sys.exit(1) + # reconstruct from wavelet coefficients data_rec = pywt.waverecn(wdec_rec, self.options['wname'], 'symmetric') - data_rec = data_rec.transpose((2, 1, 0)) # flip the axes - dim = self.options['dim'] - data_rec = data_rec[0:dim[0], 0:dim[1], 0:dim[2]] # severe issure here + data_rec = data_rec[tuple(slice(0, s) for s in self.options['dim'])] + data_rec = data_rec.flatten(order=self.options['order']) data_rec = data_rec[self.options['mask']]