<a href="https://colab.research.google.com/github/AustenLamacraft/learning-wavelets/blob/master/wavelet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [11]:
from __future__ import absolute_import, division, print_function, unicode_literals

# !pip install -q tensorflow-gpu==2.0.0-beta1
# !pip install -q tensorflow-probability
import tensorflow as tf
import tensorflow_probability as tfp
import math
import numpy as np

In [271]:
# Reset Keras/TensorFlow global state so the notebook can be re-run from a
# clean slate.
tf.keras.backend.clear_session()

In [289]:
def convolve_circular(a, b):
    '''
    Circular convolution of a vector with a kernel, built from TensorFlow ops only.

    a: vector (1D tensor); cast to float32.
    b: kernel (1D tensor, must have odd length!); cast to float32.
    The middle entry of b corresponds to the coefficient of z^0:
    b[0] b[1] b[2] b[3] b[4]
    z^-2 z^-1 z^0  z^1  z^2
    i.e. for odd len(b), with c = len(b) // 2:
        result[i] = sum_k b[k] * a[(i - (k - c)) % len(a)]

    Returns a float32 tensor of the same length as a.

    Because no NumPy round-trip is involved, gradients flow back to both a and
    b (so tf.Variable kernels can be trained), and the function also works
    inside tf.function.
    '''
    a = tf.cast(a, tf.float32)
    b = tf.cast(b, tf.float32)
    len_b = tf.compat.dimension_value(b.shape[0])
    if len_b is None:  # static length unknown (graph mode); fall back to runtime size
        len_b = int(tf.size(b))
    half = len_b // 2
    result = tf.zeros_like(a)
    if tf.compat.dimension_value(a.shape[0]) == 0:
        return result  # nothing to convolve
    for j in range(len_b):
        # tf.roll(a, s)[i] == a[(i - s) % len(a)], so s = half - j selects
        # a[(i + j - half) % len(a)], the sample paired with tap b[-j-1].
        result += b[len_b - 1 - j] * tf.roll(a, shift=half - j, axis=0)
    return result

In [278]:
# Sanity check: circularly convolve [1, 2, 3, 4] with the kernel [-1, 0, 1]
# (coefficients of z^-1, z^0, z^1); the expected result is [2, -2, -2, 2].
convolve_circular(tf.constant([1,2,3,4,]), tf.constant([-1, 0, 1]))

convolving  [1 2 3 4]  with  [-1  0  1]


<tf.Tensor: id=28320, shape=(4,), dtype=float32, numpy=array([ 2., -2., -2.,  2.], dtype=float32)>

In [290]:
# def convolve_circular_currently_unused(a, b):
#   '''
#   a: vector
#   b: kernel
#   Requires that 2*tf.size(b) <= tf.size(a). If this is not satisfied, overlap
#   will occur in the convolution.
#   '''
#   b_padding = tf.constant([[0, int(tf.size(a) - tf.size(b))]])
#   b_padded = tf.pad(b, b_padding, "CONSTANT")
#   a_fft = tf.signal.fft(tf.complex(a, 0.0))
#   b_fft = tf.signal.fft(tf.complex(b_padded, 0.0))
#   ifft = tf.signal.ifft(a_fft * b_fft)
#   return tf.cast(tf.math.real(ifft), 'float32')

In [374]:
class Lifting(tfp.bijectors.Bijector):
  '''
  One lifting step (predict with P, then update with U) of a 1D wavelet
  transform, exposed as a TensorFlow Probability bijector.

  Input to _forward: 1D tensor of even length (unbatched).
  Output of _forward: Two stacked 1D tensors of the form
  [[lowpass wavelet coefficients], [highpass wavelet coefficients]],
  i.e. shape [2, len(x) // 2].
  See https://uk.mathworks.com/help/wavelet/ug/lifting-method-for-constructing-wavelets.html
  for notation.

  Args:
    validate_args: forwarded to tfp.bijectors.Bijector.
    name: name of the bijector.
    n_lifting_coeffs: length of the random P/U filters drawn when P_coeff or
      U_coeff is not supplied.
    P_coeff: initial predict (primal lifting) filter. If None, a fresh filter
      is drawn from tf.random.uniform(shape=(n_lifting_coeffs,)).
    U_coeff: initial update (dual lifting) filter. If None, drawn likewise.

  NOTE(review): the default n_lifting_coeffs=2 gives even-length filters,
  while convolve_circular documents that its kernel must have odd length.
  TODO confirm the intended default.
  '''
  def __init__(self,
               validate_args=False,
               name="lifting",
               n_lifting_coeffs=2,
               P_coeff=None,
               U_coeff=None):
    super(Lifting, self).__init__(
        validate_args=validate_args,
        # forward consumes a rank-1 signal and produces a rank-2 [2, n] array,
        # so the inverse's minimal event rank differs from the forward's.
        forward_min_event_ndims=1,
        inverse_min_event_ndims=2,
        # Lifting steps are volume preserving: log|det J| is always 0.
        is_constant_jacobian=True,
        name=name)
    self.n_lifting_coeffs = n_lifting_coeffs
    # Draw random filters per instance. A tensor default argument would be
    # evaluated once at class-definition time and shared by every instance.
    if P_coeff is None:
      P_coeff = tf.random.uniform(shape=(n_lifting_coeffs,))
    if U_coeff is None:
      U_coeff = tf.random.uniform(shape=(n_lifting_coeffs,))
    self.P_coeff = tf.Variable(initial_value=P_coeff)  # P: predict (primal lifting)
    self.U_coeff = tf.Variable(initial_value=U_coeff)  # U: update (dual lifting)

  def _forward(self, x):
    '''Split x into even/odd samples, predict the odds, then update the evens.'''
    x_evens = x[0::2]
    x_odds = x[1::2]
    # Predict step: detail = odd samples minus the P-filtered even samples.
    evens_conv_P = convolve_circular(x_evens, self.P_coeff)
    detail = x_odds - evens_conv_P
    # Update step: average = even samples plus the U-filtered detail.
    detail_conv_U = convolve_circular(detail, self.U_coeff)
    average = x_evens + detail_conv_U
    return tf.stack([average, detail])

  def _inverse(self, y):
    '''Undo _forward: reverse the update, then the predict, then re-interleave.'''
    average = y[0,:]
    detail = y[1,:]
    detail_conv_U = convolve_circular(detail, self.U_coeff)
    x_evens = average - detail_conv_U
    evens_conv_P = convolve_circular(x_evens, self.P_coeff)
    x_odds = evens_conv_P + detail
    # Stack as columns [evens, odds] and flatten row-major: e0, o0, e1, o1, ...
    x = tf.reshape(tf.stack([x_evens, x_odds], axis=1), shape=[-1])
    return x

  def _inverse_log_det_jacobian(self, y):
    # Each lifting step adds a function of one half of the signal to the other
    # half (x_odds - P*x_evens, x_evens + U*detail), so its Jacobian is
    # block unit-triangular with determinant 1. The even/odd split and the
    # stacking are permutations (|det| = 1). Hence log|det J| = 0.
    return tf.zeros([], dtype=y.dtype)

  def _forward_log_det_jacobian(self, x):
    # Same argument as for the inverse: the lifting steps are shears, so the
    # log absolute determinant of the Jacobian is exactly 0.
    return tf.zeros([], dtype=x.dtype)

# Tests

In [276]:
def test_lifting_layer(x, P_coeff, U_coeff, y_expected, atol=0.):
    '''
    Check that one forward lifting step with the given P/U filters maps x to y_expected.

    x: 1D float32 input signal of even length.
    P_coeff, U_coeff: predict / update filters passed to Lifting.
    y_expected: expected [2, len(x)//2] output (lowpass row, highpass row).
    atol: maximum allowed absolute elementwise difference. The default of 0
      requires exact equality.
    '''
    lifting = Lifting(P_coeff=P_coeff, U_coeff=U_coeff)
    y_result = lifting._forward(x)
    # Compare shapes explicitly: elementwise comparison broadcasts, so a
    # wrongly shaped result could otherwise pass.
    assert y_result.shape == y_expected.shape, (y_result.shape, y_expected.shape)
    max_abs_err = tf.reduce_max(tf.abs(y_expected - y_result))
    assert max_abs_err <= atol, (float(max_abs_err), y_result)

In [371]:
# Haar wavelet: predict with P = [1], update with U = [1/2].
# Expected output verified with MATLAB:
#   y_expected = dwt(x, [1/2, 1/2], [1, -1])
test_lifting_layer(
    x=tf.constant([1, 2, 3, 4, 5, 6, 7, 8], dtype='float32'),
    y_expected=tf.constant([[1.5, 3.5, 5.5, 7.5],
                            [1., 1., 1., 1.]]),
    P_coeff=tf.constant([1.]),
    U_coeff=tf.constant([.5]))

In [372]:
# Wavelet with lowpass filter  h(z) = (1/8) * (2z^3 - z^2 + 2z + 6 - z^-2)
# and highpass filter          g(z) = (-1/2)z^-2 - 1/2 + z.
# The expected output can be verified by doing the circular convolutions by
# hand. MATLAB's dwt cannot be used here, since it seems to accept only
# causal filters.
test_lifting_layer(
    x=tf.constant([1, 2, 3, 4, 5, 6, 7, 8], dtype='float32'),
    y_expected=tf.constant([[1., 4., 6., 7.],
                            [-2., 2., 2., 2.]]),
    P_coeff=tf.constant([0, .5, .5]),
    U_coeff=tf.constant([.25, .25, 0]))

In [373]:
def test_lifting_layer_inverse(x=None, rtol=1e-5):
    '''
    Check that Lifting._inverse undoes Lifting._forward for randomly initialised filters.

    x: 1D float32 signal of even length; defaults to [1, 2, ..., 8].
    rtol: allowed reconstruction error, relative to the norm of x.
    '''
    if x is None:
        x = tf.constant([1,2,3,4,5,6,7,8], dtype='float32')
    lifting = Lifting()
    y = lifting._forward(x)
    y_inv = lifting._inverse(y)
    assert y_inv.shape == x.shape, (y_inv.shape, x.shape)
    # float32 rounding makes exact equality impossible. An absolute bound near
    # 1e-6 is at ulp scale for values of magnitude ~8 and so fails
    # intermittently, so the error is measured relative to the signal's norm.
    err = tf.norm(x - y_inv)
    assert err <= rtol * tf.norm(x), (float(err), y_inv)

test_lifting_layer_inverse()