Skip to content

Commit

Permalink
Merge pull request #104 from kwohlfahrt/ndim
Browse files Browse the repository at this point in the history
Faster idwtn/dwtn
  • Loading branch information
aaren committed Aug 14, 2015
2 parents fde6b5a + 13bd58d commit e3bab45
Show file tree
Hide file tree
Showing 9 changed files with 500 additions and 98 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ matrix:
- python: 2.6
env:
- OPTIMIZE=-OO
- NUMPYSPEC="numpy==1.6.2"
- NUMPYSPEC="numpy==1.7.2"
- python: 2.7
env:
- NUMPYSPEC="--upgrade git+git://github.com/numpy/numpy.git@v1.9.1"
Expand Down
71 changes: 14 additions & 57 deletions pywt/multidim.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import numpy as np

from ._pywt import Wavelet, MODES
from ._pywt import dwt, idwt, swt, downcoef, upcoef
from ._pywt import dwt, idwt, swt, downcoef, upcoef, dwt_axis, idwt_axis


def dwt2(data, wavelet, mode='sym'):
Expand Down Expand Up @@ -225,30 +225,25 @@ def dwtn(data, wavelet, mode='sym'):
"""
data = np.asarray(data)
dim = data.ndim
if dim < 1:
ndim = data.ndim

if data.dtype == np.dtype('object'):
raise TypeError("Input must be a numeric array-like")
if ndim < 1:
raise ValueError("Input data must be at least 1D")
coeffs = [('', data)]

def _downcoef(data, wavelet, mode, type):
"""Adapts pywt.downcoef call for apply_along_axis"""
return downcoef(type, data, wavelet, mode, level=1)

for axis in range(dim):
for axis in range(ndim):
new_coeffs = []
for subband, x in coeffs:
new_coeffs.extend([
(subband + 'a', np.apply_along_axis(_downcoef, axis, x,
wavelet, mode, 'a')),
(subband + 'd', np.apply_along_axis(_downcoef, axis, x,
wavelet, mode, 'd'))])

cA, cD = dwt_axis(x, wavelet, mode, axis)
new_coeffs.extend([(subband + 'a', cA),
(subband + 'd', cD)])
coeffs = new_coeffs

return dict(coeffs)


def idwtn(coeffs, wavelet, mode='sym', take=None):
def idwtn(coeffs, wavelet, mode='sym'):
"""
Single-level n-dimensional Discrete Wavelet Transform.
Expand All @@ -261,14 +256,7 @@ def idwtn(coeffs, wavelet, mode='sym', take=None):
Wavelet to use
mode : str, optional
Signal extension mode used in the decomposition,
see MODES (default: 'sym'). Overridden by `take`.
take : int or iterable of int or None, optional
Number of values to take from the center of the idwtn for each axis.
If 0, the entire reverse transformation will be used, including
parts generated from padding in the forward transform.
If None (default), will be calculated from `mode` to be the size of the
original data, rounded up to the nearest multiple of 2.
Passed to `upcoef`.
see MODES (default: 'sym').
Returns
-------
Expand All @@ -294,46 +282,15 @@ def idwtn(coeffs, wavelet, mode='sym', take=None):
if any(s != coeff_shape for s in coeff_shapes):
raise ValueError("`coeffs` must all be of equal size (or None)")

if take is not None:
try:
takes = list(islice(take, dims))
takes.reverse()
except TypeError:
takes = repeat(take, dims)
else:
# As in src/common.c
if mode == MODES.per:
takes = [2*s for s in reversed(coeff_shape)]
else:
takes = [2*s - wavelet.rec_len + 2 for s in reversed(coeff_shape)]

def _upcoef(coeffs, wavelet, take, type):
"""Adapts pywt.upcoef call for apply_along_axis"""
return upcoef(type, coeffs, wavelet, level=1, take=take)

for axis, take in zip(reversed(range(dims)), takes):
for axis in reversed(range(dims)):
new_coeffs = {}
new_keys = [''.join(coeff) for coeff in product('ad', repeat=axis)]

for key in new_keys:
L = coeffs.get(key + 'a')
H = coeffs.get(key + 'd')

if L is not None:
L = np.apply_along_axis(_upcoef, axis, L, wavelet, take, 'a')

if H is not None:
H = np.apply_along_axis(_upcoef, axis, H, wavelet, take, 'd')

if H is None and L is None:
new_coeffs[key] = None
elif H is None:
new_coeffs[key] = L
elif L is None:
new_coeffs[key] = H
else:
new_coeffs[key] = L + H

new_coeffs[key] = idwt_axis(L, H, wavelet, mode, axis)
coeffs = new_coeffs

return coeffs['']
Expand Down
123 changes: 118 additions & 5 deletions pywt/src/_pywt.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,119 @@ def _dwt(np.ndarray[data_t, ndim=1] data, object wavelet, object mode='sym'):
return (cA, cD)


cpdef dwt_axis(np.ndarray data, object wavelet, object mode='sym', unsigned int axis=0):
cdef Wavelet w = c_wavelet_from_object(wavelet)
cdef common.MODE _mode = _try_mode(mode)
cdef common.ArrayInfo data_info, output_info
cdef np.ndarray cD, cA
cdef size_t[::1] output_shape

data = data.astype(_check_dtype(data), copy=False)

output_shape = (<size_t [:data.ndim]> <size_t *> data.shape).copy()
output_shape[axis] = common.dwt_buffer_length(data.shape[axis], w.dec_len, _mode)

cA = np.empty(output_shape, data.dtype)
cD = np.empty(output_shape, data.dtype)

data_info.ndim = data.ndim
data_info.strides = <index_t *> data.strides
data_info.shape = <size_t *> data.shape

output_info.ndim = cA.ndim
output_info.strides = <index_t *> cA.strides
output_info.shape = <size_t *> cA.shape

if data.dtype == np.float64:
if c_wt.double_downcoef_axis(<double *> data.data, data_info,
<double *> cA.data, output_info,
w.w, axis, common.COEF_APPROX, _mode):
raise RuntimeError("C wavelet transform failed")
if c_wt.double_downcoef_axis(<double *> data.data, data_info,
<double *> cD.data, output_info,
w.w, axis, common.COEF_DETAIL, _mode):
raise RuntimeError("C wavelet transform failed")
elif data.dtype == np.float32:
if c_wt.float_downcoef_axis(<float *> data.data, data_info,
<float *> cA.data, output_info,
w.w, axis, common.COEF_APPROX, _mode):
raise RuntimeError("C wavelet transform failed")
if c_wt.float_downcoef_axis(<float *> data.data, data_info,
<float *> cD.data, output_info,
w.w, axis, common.COEF_DETAIL, _mode):
raise RuntimeError("C wavelet transform failed")
else:
raise TypeError("Array must be floating point, not {}"
.format(data.dtype))
return (cA, cD)


# TODO: Use idwt rather than upcoef, which requires `mode` but not `take`
cpdef idwt_axis(np.ndarray coefs_a, np.ndarray coefs_d, object wavelet,
object mode='sym', unsigned int axis=0):
cdef Wavelet w = c_wavelet_from_object(wavelet)
cdef common.ArrayInfo a_info, d_info, output_info
cdef np.ndarray output
cdef np.dtype output_dtype
cdef size_t[::1] output_shape
cdef common.MODE _mode = _try_mode(mode)

if coefs_a is not None:
if coefs_d is not None and coefs_d.dtype.itemsize > coefs_a.dtype.itemsize:
coefs_a = coefs_a.astype(_check_dtype(coefs_d), copy=False)
else:
coefs_a = coefs_a.astype(_check_dtype(coefs_a), copy=False)
a_info.ndim = coefs_a.ndim
a_info.strides = <index_t *> coefs_a.strides
a_info.shape = <size_t *> coefs_a.shape
if coefs_d is not None:
if coefs_a is not None and coefs_a.dtype.itemsize > coefs_d.dtype.itemsize:
coefs_d = coefs_d.astype(_check_dtype(coefs_a), copy=False)
else:
coefs_d = coefs_d.astype(_check_dtype(coefs_d), copy=False)
d_info.ndim = coefs_d.ndim
d_info.strides = <index_t *> coefs_d.strides
d_info.shape = <size_t *> coefs_d.shape

if coefs_a is not None:
output_shape = (<size_t [:coefs_a.ndim]> <size_t *> coefs_a.shape).copy()
output_shape[axis] = common.idwt_buffer_length(coefs_a.shape[axis],
w.rec_len, _mode)
output_dtype = coefs_a.dtype
elif coefs_d is not None:
output_shape = (<size_t [:coefs_d.ndim]> <size_t *> coefs_d.shape).copy()
output_shape[axis] = common.idwt_buffer_length(coefs_d.shape[axis],
w.rec_len, _mode)
output_dtype = coefs_d.dtype
else:
return None;

output = np.empty(output_shape, output_dtype)

output_info.ndim = output.ndim
output_info.strides = <index_t *> output.strides
output_info.shape = <size_t *> output.shape

if output.dtype == np.float64:
if c_wt.double_idwt_axis(<double *> coefs_a.data if coefs_a is not None else NULL,
&a_info if coefs_a is not None else NULL,
<double *> coefs_d.data if coefs_d is not None else NULL,
&d_info if coefs_d is not None else NULL,
<double *> output.data, output_info,
w.w, axis, _mode):
raise RuntimeError("C inverse wavelet transform failed")
if output.dtype == np.float32:
if c_wt.float_idwt_axis(<float *> coefs_a.data if coefs_a is not None else NULL,
&a_info if coefs_a is not None else NULL,
<float *> coefs_d.data if coefs_d is not None else NULL,
&d_info if coefs_d is not None else NULL,
<float *> output.data, output_info,
w.w, axis, _mode):
raise RuntimeError("C inverse wavelet transform failed")

return output


def dwt_coeff_len(data_len, filter_len, mode='sym'):
"""
dwt_coeff_len(data_len, filter_len, mode='sym')
Expand Down Expand Up @@ -745,16 +858,16 @@ def _try_mode(mode):
raise TypeError("Invalid mode: {0}".format(str(mode)))


def _check_dtype(data):
cdef np.dtype _check_dtype(data):
"""Check for cA/cD input what (if any) the dtype is."""
cdef np.dtype dt
try:
dt = data.dtype
if not dt == np.float32:
if dt not in (np.float64, np.float32):
# integer input was always accepted; convert to float64
dt = np.float64
dt = np.dtype('float64')
except AttributeError:
dt = np.float64

dt = np.dtype('float64')
return dt


Expand Down
20 changes: 19 additions & 1 deletion pywt/src/c_wt.pxd
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
# Copyright (c) 2006-2012 Filip Wasilewski <http://en.ig.ma/>
# See COPYING for license details.

from common cimport MODE, index_t
from common cimport MODE, index_t, ArrayInfo, Coefficient
from wavelet cimport Wavelet


cdef extern from "wt.h":
# Cython does not know the 'restrict' keyword
cdef int double_downcoef_axis(const double * const input, const ArrayInfo input_info,
double * const output, const ArrayInfo output_info,
const Wavelet * const wavelet, const size_t axis,
const Coefficient detail, const MODE mode)
cdef int double_idwt_axis(const double * const coefs_a, const ArrayInfo * const a_info,
const double * const coefs_d, const ArrayInfo * const d_info,
double * const output, const ArrayInfo output_info,
const Wavelet * const wavelet, const size_t axis,
const MODE mode)
cdef int double_dec_a(const double * const input, const size_t input_len,
const Wavelet * const wavelet,
double * const output, const size_t output_len,
Expand Down Expand Up @@ -35,6 +44,15 @@ cdef extern from "wt.h":
double output[], index_t output_len, int level)


cdef int float_downcoef_axis(const float * const input, const ArrayInfo input_info,
float * const output, const ArrayInfo output_info,
const Wavelet * const wavelet, const size_t axis,
const Coefficient detail, const MODE mode)
cdef int float_idwt_axis(const float * const coefs_a, const ArrayInfo * const a_info,
const float * const coefs_d, const ArrayInfo * const d_info,
float * const output, const ArrayInfo output_info,
const Wavelet * const wavelet, const size_t axis,
const MODE mode)
cdef int float_dec_a(const float * const input, const size_t input_len,
const Wavelet * const wavelet,
float * const output, const size_t output_len,
Expand Down
10 changes: 10 additions & 0 deletions pywt/src/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@
#include <intrin.h>
#endif

typedef struct {
size_t * shape;
index_t * strides;
size_t ndim;
} ArrayInfo;

typedef enum {
COEF_APPROX = 0,
COEF_DETAIL = 1,
} Coefficient;

/* Signal extension modes */
typedef enum {
Expand Down
9 changes: 9 additions & 0 deletions pywt/src/common.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,15 @@ cdef extern from "common.h":
cdef void* wtcalloc(long len, long size)
cdef void wtfree(void* ptr)

ctypedef struct ArrayInfo:
size_t * shape
index_t * strides
size_t ndim

ctypedef enum Coefficient:
COEF_APPROX = 0
COEF_DETAIL = 1

ctypedef enum MODE:
MODE_INVALID = -1
MODE_ZEROPAD = 0
Expand Down

0 comments on commit e3bab45

Please sign in to comment.