<a href="https://colab.research.google.com/github/AustenLamacraft/learning-wavelets/blob/master/wavelet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [11]:
from __future__ import absolute_import, division, print_function, unicode_literals

# !pip install -q tensorflow-gpu==2.0.0-beta1
# !pip install -q tensorflow-probability
import tensorflow as tf
import tensorflow_probability as tfp
import math
import numpy as np

In [12]:
# Clear Keras' global backend state so a fresh "Run All" starts from a clean
# session. NOTE(review): none of the visible cells build Keras models, so this
# is likely a no-op here; Python objects from earlier runs are not affected.
tf.keras.backend.clear_session()  # For easy reset of notebook state.

In [13]:
def convolve_circular(a, b):
  '''Circular (periodic) convolution of a 1-D signal with a 1-D kernel via the FFT.

  Computes c[n] = sum_k b[k] * a[(n - k) mod N] with N = tf.size(a). It does this
  by zero-padding b to length N and multiplying the two spectra.

  a: vector of length N (float32 or float64).
  b: kernel of length <= N. It is cast to a's dtype.
  Returns: real tensor with the same length and dtype as a.

  The signal is treated as periodic, so the kernel wraps around the end of a.
  That wrap-around is the intended behaviour, not an error. The only
  requirement is tf.size(b) <= tf.size(a); a longer kernel cannot be padded
  to length N, and an InvalidArgumentError is raised.
  '''
  a = tf.convert_to_tensor(a)
  b = tf.cast(b, a.dtype)
  n_pad = tf.size(a) - tf.size(b)
  tf.debugging.assert_non_negative(
      n_pad, message="convolve_circular: kernel b is longer than signal a")
  # Tensor-valued paddings (no int(...) conversion) keep this usable inside tf.function.
  b_padded = tf.pad(b, [[0, n_pad]], "CONSTANT")
  # Build the imaginary parts with zeros_like so float64 inputs become complex128
  # instead of clashing with a float32 literal.
  a_fft = tf.signal.fft(tf.complex(a, tf.zeros_like(a)))
  b_fft = tf.signal.fft(tf.complex(b_padded, tf.zeros_like(b_padded)))
  # real() of complex64/complex128 is float32/float64, i.e. the input dtype.
  return tf.math.real(tf.signal.ifft(a_fft * b_fft))

In [31]:
class Lifting(tfp.bijectors.Bijector):
  '''One level of a learnable lifting-scheme wavelet transform as a TFP bijector.

  See https://uk.mathworks.com/help/wavelet/ug/lifting-method-for-constructing-wavelets.html
  for notation.

  Forward: a 1-D signal x of even length 2n is split into evens x[0::2] and
  odds x[1::2], then
      detail  = odds  - P * evens     (predict)
      average = evens + U * detail    (update)
  where * is circular convolution with the learnable kernels P_coeff and U_coeff.
  The result is tf.stack([average, detail]), with shape (2, n). The inverse
  undoes the two steps in reverse order and re-interleaves evens and odds.

  Only a single unbatched signal is supported, because slicing is along axis 0.
  An odd-length x gives mismatched even/odd lengths and fails in _forward.

  NOTE(review): Sweldens' terminology calls the predict step "dual lifting" and
  the update step "primal lifting". The inline comments on P_coeff/U_coeff say
  the opposite, so confirm against the linked page.
  TODO: override _forward_event_shape/_inverse_event_shape (and their _tensor
  variants) before using this inside a TransformedDistribution, since the event
  shape changes from (2n,) to (2, n).
  '''
  def __init__(self, validate_args=False, name="lifting", n_lifting_coeffs=2):
    super(Lifting, self).__init__(
        validate_args=validate_args,
        # Every step is linear in x, so the Jacobian does not depend on x.
        is_constant_jacobian=True,
        # An event is a rank-1 signal (2n,) in x-space but a rank-2 array
        # (2, n) in y-space, so the two minimum event ranks differ.
        forward_min_event_ndims=1,
        inverse_min_event_ndims=2,
        name=name)
    self.n_lifting_coeffs = n_lifting_coeffs
    # Kernels are initialised with tf.random.uniform, so the global TF seed applies.
    self.P_coeff = tf.Variable(initial_value=tf.random.uniform(shape=(n_lifting_coeffs,)))  # P: predict (primal lifting)
    self.U_coeff = tf.Variable(initial_value=tf.random.uniform(shape=(n_lifting_coeffs,)))  # U: update (dual lifting)

  def _forward(self, x):
    '''Split x into evens/odds, then apply the predict and update steps.

    x: rank-1 tensor of even length 2n.
    Returns a (2, n) tensor whose rows are [average, detail].
    '''
    x_evens = x[0::2]
    x_odds = x[1::2]
    # Predict: detail is what is left of the odds after predicting them from the evens.
    detail = x_odds - convolve_circular(x_evens, self.P_coeff)
    # Update: correct the evens using the detail.
    average = x_evens + convolve_circular(detail, self.U_coeff)
    return tf.stack([average, detail])

  def _inverse(self, y):
    '''Undo _forward: unstack, reverse the update, then the predict, then re-interleave.

    y: (2, n) tensor with rows [average, detail], as produced by _forward.
    Returns the rank-1 signal of length 2n.
    '''
    average = y[0,:]
    detail = y[1,:]
    # The update was applied last in _forward, so undo it first.
    x_evens = average - convolve_circular(detail, self.U_coeff)
    # Then undo the predict step.
    x_odds = detail + convolve_circular(x_evens, self.P_coeff)
    # Stacking on axis=1 gives rows (even_i, odd_i); flattening interleaves them.
    x = tf.reshape(tf.stack([x_evens, x_odds], axis=1), shape=[-1])
    return x

  def _inverse_log_det_jacobian(self, y):
    # The inverse of a volume-preserving map is volume-preserving: log|det J| = 0.
    return tf.zeros([], dtype=y.dtype)

  def _forward_log_det_jacobian(self, x):
    # The even/odd split is a permutation (|det| = 1). The predict and update
    # steps are shears whose Jacobians are block unit-triangular (det = 1).
    # Hence log|det J| = 0 for every x.
    return tf.zeros([], dtype=x.dtype)

In [32]:
# Fix the global seed: Lifting draws its kernels with tf.random.uniform, so the
# printed y is otherwise different on every run.
tf.random.set_seed(0)
lifting = Lifting()

x = tf.constant([1,2,3,4,5,6,7,8], dtype='float32')

# Call the private hooks on purpose. The public forward()/inverse() may answer
# inverse(forward(x)) from the bijector's cache, which would make the round-trip
# check below vacuous.
y = lifting._forward(x)
print("y = ",y)
y_inv = lifting._inverse(y)
print("y_inv = ", y_inv)
# inverse(forward(x)) must reproduce x up to float32 FFT round-off.
np.testing.assert_allclose(y_inv.numpy(), x.numpy(), rtol=1e-5, atol=1e-5)

y =  tf.Tensor(
[[-1.7183032  1.3564123  6.046655   6.9930687]
 [-3.8287673  1.2341118  0.4088254 -0.416461 ]], shape=(2, 4), dtype=float32)
y_inv =  tf.Tensor([1. 2. 3. 4. 5. 6. 7. 8.], shape=(8,), dtype=float32)
