<a href="https://colab.research.google.com/github/AustenLamacraft/learning-wavelets/blob/master/wavelet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
from __future__ import absolute_import, division, print_function, unicode_literals

# !pip install -q tensorflow-gpu==2.0.0-beta1
# !pip install -q tensorflow-probability
import tensorflow as tf
import tensorflow_probability as tfp
import math
import numpy as np

In [3]:
tf.keras.backend.clear_session()  # For easy reset of notebook state.

# Fix TensorFlow's global seed so the randomly initialised lifting coefficients
# are reproducible on Restart & Run All.
SEED = 0
tf.random.set_seed(SEED)

# Helper functions

In [4]:
def convolve_circular(a, b):
    '''
    Circular convolution of a vector with a centred, odd-length kernel.

    a: vector (1D tensor-like) of length n.
    b: kernel (1D tensor-like, must have odd length!). The middle entry of b
       corresponds to the coefficient of z^0:
           b[0] b[1] b[2] b[3] b[4]
           z^-2 z^-1 z^0  z^1  z^2
    Returns a float32 tensor of length n with
        result[i] = sum_m b[m] * a[(i - (m - len(b)//2)) % n]

    Built only from TensorFlow ops (tf.roll / tf.add_n), so gradients flow back
    to both a and b, e.g. to trainable lifting coefficients held in tf.Variables.
    Raises ValueError if b has even length.
    '''
    a = tf.cast(tf.convert_to_tensor(a), tf.float32)
    b = tf.cast(tf.convert_to_tensor(b), tf.float32)
    len_b = int(tf.size(b))
    if len_b % 2 == 0:
        raise ValueError(
            'convolve_circular: kernel must have odd length, got %d' % len_b)
    if int(tf.size(a)) == 0:
        return tf.zeros([0], dtype=tf.float32)
    center = len_b // 2
    # b[m] is the coefficient of z^k with k = m - center; that term contributes
    # b[m] * a[(i - k) % n], i.e. a circularly shifted by k = tf.roll(a, shift=k).
    return tf.add_n([b[m] * tf.roll(a, shift=m - center, axis=0)
                     for m in range(len_b)])

In [5]:
# Quick sanity check of the z-convention: kernel [-1, 0, 1] on a length-4 signal.
convolve_circular(a=tf.constant([1, 2, 3, 4]), b=tf.constant([-1, 0, 1]))

<tf.Tensor: id=154, shape=(4,), dtype=float32, numpy=array([ 2., -2., -2.,  2.], dtype=float32)>

In [6]:
# NOTE(review): unused FFT-based alternative to convolve_circular, kept for reference.
# It is NOT a drop-in replacement: zero-padding b at the end makes b[0] the z^0
# coefficient (a causal kernel), whereas convolve_circular treats the middle
# entry of b as z^0, so the two results differ by a circular shift of len(b)//2.
# NOTE(review): multiplying two length-n FFTs already yields an exact length-n
# circular convolution; the size requirement stated below looks like it only
# matters if a *linear* convolution is wanted -- confirm before relying on it.
# def convolve_circular_currently_unused(a, b):
#   '''
#   a: vector
#   b: kernel
#   Requires that 2*tf.size(b) <= tf.size(a). If this is not satisfied, overlap
#   will occur in the convolution.
#   '''
#   b_padding = tf.constant([[0, int(tf.size(a) - tf.size(b))]])
#   b_padded = tf.pad(b, b_padding, "CONSTANT")
#   a_fft = tf.signal.fft(tf.complex(a, 0.0))
#   b_fft = tf.signal.fft(tf.complex(b_padded, 0.0))
#   ifft = tf.signal.ifft(a_fft * b_fft)
#   return tf.cast(tf.math.real(ifft), 'float32')

# Layers

In [18]:
class LazyWavelet(tfp.bijectors.Bijector):
  '''
  This layer corresponds to the downsampling step of a single-step wavelet transform.
  Input to _forward: 1D tensor of even length.
  Output of _forward: Two stacked 1D tensors of the form [[even x components], [odd x components]]
  See https://uk.mathworks.com/help/wavelet/ug/lifting-method-for-constructing-wavelets.html
  for notation.

  The transform only reorders the entries of x (no arithmetic), so its Jacobian
  is a permutation matrix: |det J| = 1 and both log-determinants are 0.
  '''
  def __init__(self,
               validate_args=False,
               name="lazy_wavelet"):
    super().__init__(
        validate_args=validate_args,
        forward_min_event_ndims=1,
        name=name)

  def _forward(self, x):
    x_evens = x[0::2]
    x_odds = x[1::2]
    return tf.stack([x_evens, x_odds])

  def _inverse(self, y):
    x_evens = y[0,:]
    x_odds = y[1,:]
    x = tf.reshape(tf.stack([x_evens, x_odds], axis=1), shape=[-1])  # interleave evens and odds
    return x

  def _inverse_log_det_jacobian(self, y):
    # Permutation => log|det J| = 0. Return a dtype-matched scalar tensor rather
    # than a Python int so it can be reduced / summed with other bijectors'
    # float log-determinants (e.g. inside tfp.bijectors.Chain).
    return tf.zeros([], dtype=y.dtype)

  def _forward_log_det_jacobian(self, x):
    # Permutation => log|det J| = 0 (typed scalar, as in the inverse).
    return tf.zeros([], dtype=x.dtype)

In [113]:
class Lifting(tfp.bijectors.Bijector):
  '''
  This layer corresponds to two elementary matrices of the polyphase matrix of a single-step wavelet transform.
  Input to _forward: Two stacked 1D tensors of the form [[lowpass wavelet coefficients], [highpass wavelet coefficients]],
      i.e. the output of LazyWavelet or another Lifting layer.
  Output of _forward: Two stacked 1D tensors of the form [[lowpass wavelet coefficients], [highpass wavelet coefficients]].
  See https://uk.mathworks.com/help/wavelet/ug/lifting-method-for-constructing-wavelets.html
  for notation.

  Forward map (each * is convolve_circular with a trainable kernel):
      detail  = x_odds  - P * x_evens     (P: predict)
      average = x_evens + U * detail      (U: update)
  Each step adds a function of the *other* row to a row with unit coefficient,
  so the Jacobian is unit-triangular: det J = 1 and both log-determinants are 0
  for any values of P and U.

  Constructor args:
    validate_args, name: forwarded to tfp.bijectors.Bijector.
    n_lifting_coeffs: length of the randomly initialised kernels drawn when
      P_coeff / U_coeff are omitted (must be odd, as convolve_circular requires).
    P_coeff, U_coeff: optional initial values for the predict / update kernels.
      When None, a fresh tf.random.uniform sample of shape (n_lifting_coeffs,)
      is drawn for this instance.
  '''
  def __init__(self,
               validate_args=False,
               name="lifting",
               n_lifting_coeffs=3,
               P_coeff=None,
               U_coeff=None):
    super().__init__(
        validate_args=validate_args,
        forward_min_event_ndims=1,
        name=name)
    self.n_lifting_coeffs = n_lifting_coeffs
    # Draw random initial kernels here rather than in the signature: a default
    # expression is evaluated once at class-definition time, which made every
    # Lifting() share identical "random" coefficients and ignored n_lifting_coeffs.
    if P_coeff is None:
      P_coeff = tf.random.uniform(shape=(n_lifting_coeffs,))
    if U_coeff is None:
      U_coeff = tf.random.uniform(shape=(n_lifting_coeffs,))
    self.P_coeff = tf.Variable(initial_value=P_coeff)  # P: predict (primal lifting)
    self.U_coeff = tf.Variable(initial_value=U_coeff)  # U: update (dual lifting)

  def _forward(self, x):
    x_evens = x[0,:]
    x_odds = x[1,:]
    evens_conv_P = convolve_circular(x_evens, self.P_coeff)
    detail = x_odds - evens_conv_P
    detail_conv_U = convolve_circular(detail, self.U_coeff)
    average = x_evens + detail_conv_U
    return tf.stack([average, detail])

  def _inverse(self, y):
    # Undo the two lifting steps in reverse order (update first, then predict).
    average = y[0,:]
    detail = y[1,:]
    detail_conv_U = convolve_circular(detail, self.U_coeff)
    x_evens = average - detail_conv_U
    evens_conv_P = convolve_circular(x_evens, self.P_coeff)
    x_odds = evens_conv_P + detail
    x = tf.stack([x_evens, x_odds])
    return x

  def _inverse_log_det_jacobian(self, y):
    # Unit-triangular Jacobian => log|det J| = 0. Return a dtype-matched scalar
    # tensor rather than a Python int so it can be combined with other bijectors'
    # float log-determinants (e.g. inside tfp.bijectors.Chain).
    return tf.zeros([], dtype=y.dtype)

  def _forward_log_det_jacobian(self, x):
    # Unit-triangular Jacobian => log|det J| = 0 (typed scalar, as in the inverse).
    return tf.zeros([], dtype=x.dtype)

# Tests

## LazyWavelet layer

In [40]:
def test_lazy_wavelet_layer(x, y_expected):
    '''Check that LazyWavelet._forward splits x into [evens, odds] exactly.'''
    y_result = LazyWavelet()._forward(x)
    elementwise_match = tf.math.equal(y_expected, y_result)
    assert tf.reduce_all(elementwise_match)

test_lazy_wavelet_layer(
    tf.constant([1, 2, 3, 4, 5, 6, 7, 8], dtype='float32'),
    tf.constant([[1, 3, 5, 7], [2, 4, 6, 8]], dtype='float32'))

In [112]:
def test_lazy_wavelet_layer_inverse():
    '''Round trip: LazyWavelet._inverse(LazyWavelet._forward(x)) gives back x.'''
    layer = LazyWavelet()
    signal = tf.constant([1, 2, 3, 4, 5, 6, 7, 8], dtype='float32')
    reconstructed = layer._inverse(layer._forward(signal))
    assert tf.norm(signal - reconstructed) < 1e-06

test_lazy_wavelet_layer_inverse()

## Lifting layer

In [114]:
def test_lifting_layer(x, P_coeff, U_coeff, y_expected):
    '''Check Lifting.forward with the given kernels against a hand-computed output.'''
    layer = Lifting(P_coeff=P_coeff, U_coeff=U_coeff)
    assert tf.reduce_all(tf.math.equal(y_expected, layer.forward(x)))

In [115]:
# Haar wavelet (predict P = 1, update U = 1/2), verified with MATLAB:
# y_expected = dwt(x, [1/2, 1/2], [1, -1])
test_lifting_layer(
    x=tf.constant([[1, 3, 5, 7], [2, 4, 6, 8]], dtype='float32'),
    P_coeff=tf.constant([1.]),
    U_coeff=tf.constant([.5]),
    y_expected=tf.constant([[1.5, 3.5, 5.5, 7.5], [1., 1., 1., 1.]]),
)

In [116]:
# Wavelet with lowpass filter h(z) = (1/8) * (2z^3 - z^2 + 2z + 6 - z^-2) and highpass
# filter g(z) = (-1/2)z^-2 - 1/2 + z.
# The expected output can be checked by carrying out the circular convolutions by
# hand; MATLAB's dwt is not usable here, since it seems to accept only causal filters.
test_lifting_layer(
    x=tf.constant([[1, 3, 5, 7], [2, 4, 6, 8]], dtype='float32'),
    P_coeff=tf.constant([0, .5, .5]),
    U_coeff=tf.constant([.25, .25, 0]),
    y_expected=tf.constant([[1., 4., 6., 7.], [-2., 2., 2., 2.]]),
)

In [117]:
def test_lifting_layer_inverse(atol=1e-4):
    '''
    Round trip for a randomly initialised Lifting layer: _inverse(_forward(x)) ~= x.

    atol bounds the Euclidean norm of the reconstruction error. An absolute bound
    of 1e-06 is too tight for float32 here: the update step adds and then subtracts
    U * detail, which with random kernels in [0, 1) can reach tens of units for this
    x. At that magnitude the float32 spacing is already a few 1e-06, so rounding
    alone can exceed 1e-06 and make the test flaky.
    '''
    lifting = Lifting()
    x = tf.constant([[1,3,5,7], [2,4,6,8]], dtype='float32')
    y = lifting._forward(x)
    y_inv = lifting._inverse(y)
    reconstruction_error = float(tf.norm(x - y_inv))
    assert reconstruction_error < atol, (
        'reconstruction error %g exceeds tolerance %g' % (reconstruction_error, atol))

test_lifting_layer_inverse()

## Chain(Lifting, LazyWavelet)

In [118]:
def test_chain_of_lifting_and_lazy_wavelet():
    '''Chain([Lifting, LazyWavelet]).forward must split x and then lift it as expected.'''
    signal = tf.constant([1, 2, 3, 4, 5, 6, 7, 8], dtype='float32')
    expected = tf.constant([[1., 4., 6., 7.], [-2., 2., 2., 2.]])
    lifting = Lifting(P_coeff=tf.constant([0, .5, .5]), U_coeff=tf.constant([.25, .25, 0]))
    chain = tfp.bijectors.Chain([lifting, LazyWavelet()])
    assert tf.reduce_all(tf.math.equal(expected, chain.forward(signal)))

test_chain_of_lifting_and_lazy_wavelet()

In [119]:
def test_chain_of_lifting_and_lazy_wavelet_inverse():
    '''Round trip through Chain([Lifting, LazyWavelet]) must give back x.'''
    signal = tf.constant([1, 2, 3, 4, 5, 6, 7, 8], dtype='float32')
    chain = tfp.bijectors.Chain([
        Lifting(P_coeff=tf.constant([0, .5, .5]), U_coeff=tf.constant([.25, .25, 0])),
        LazyWavelet(),
    ])
    reconstructed = chain._inverse(chain.forward(signal))
    assert tf.norm(signal - reconstructed) < 1e-06

test_chain_of_lifting_and_lazy_wavelet_inverse()

# Chaining multiple lifting layers

In [132]:
# Five lifting steps on top of the lazy even/odd split (the Chain list is applied
# right to left, so LazyWavelet runs first).
network = tfp.bijectors.Chain([Lifting() for _ in range(5)] + [LazyWavelet()])

In [145]:
# Push a toy signal through the randomly initialised network and print the result.
x = tf.constant([1, 2, 3, 4, 5, 6, 7, 8], dtype='float32')
y = network.forward(x)
print(y)

# Mapping back through the inverse should reproduce x (up to float32 round-off).
network.inverse(y)

tf.Tensor(
[[-15.473223    -0.96022964  13.144677     4.2084436 ]
 [ -3.5964892   18.506632     0.16573715 -27.894365  ]], shape=(2, 4), dtype=float32)


<tf.Tensor: id=26337, shape=(8,), dtype=float32, numpy=array([1., 2., 3., 4., 5., 6., 7., 8.], dtype=float32)>

# Polyphase matrix -> wavelet filters

## How to reconstruct lowpass and highpass wavelet filters from a given polyphase matrix.

Assume that we are using a polyphase matrix $M =
  \begin{bmatrix}
    a(z) & b(z) \\
    c(z) & d(z)
  \end{bmatrix}$ to do a wavelet transform of a vector X: $
  \begin{bmatrix}
    \text{lowpass coefficients}(z) \\
    \text{highpass coefficients}(z)
  \end{bmatrix} = M
  \begin{bmatrix}
    X_{e}(z) \\
    X_{o}(z)
  \end{bmatrix}$
  
Let the lowpass filter corresponding to M be $h(z) = h_e(z^2) + z^{-1}h_o(z^2)$.

Let the highpass filter corresponding to M be $g(z) = g_e(z^2) + z^{-1}g_o(z^2)$.

Then, using formulas 4.25 and 4.26 from "Wavelets and Filter Banks" [Strang, Nguyen], we get $$M =
  \begin{bmatrix}
    a(z) & b(z) \\
    c(z) & d(z)
  \end{bmatrix} = 
  \begin{bmatrix}
    h_e(z) & z^{-1}h_o(z) \\
    g_e(z) & z^{-1}g_o(z)
  \end{bmatrix}
  $$
and thus $$\begin{aligned} h(z) &= a(z^2) + z\,b(z^2) \\ g(z) &= c(z^2) + z\,d(z^2) \end{aligned}$$

  