In [2]:
import tensorflow as tf
from tensorflow import keras

class wavenetblock(tf.keras.layers.Layer):
  def __init__(self, filters, filter_size, dilation):
    super(wavenetblock, self).__init__()
    self.conv_s = tf.keras.layers.Conv1D(filters, filter_size, padding='same', dilation_rate=dilation)
    self.conv_t = tf.keras.layers.Conv1D(filters, filter_size, padding='same', dilation_rate=dilation)
    self.tanh = tf.keras.layers.Activation('tanh')
    self.sigmoid = tf.keras.layers.Activation('sigmoid')
    self.multiply = tf.keras.layers.Multiply()
    self.conv1 = tf.keras.layers.Conv1D(filters, filter_size, padding='same')
    self.conv2 = tf.keras.layers.Conv1D(filters, filter_size, padding='same')
    self.add = tf.keras.layers.Add()

  def call(self, inputs):
    xs = self.conv_s(inputs)
    xt = self.conv_t(inputs)
    x_tanh = self.tanh(xt)
    x_sigmoid = self.sigmoid(xs)
    x_multiply = self.multiply([x_tanh, x_sigmoid])
    skip = self.conv1(x_multiply)
    res = self.conv2(x_multiply)
    out = self.add([res, inputs])

    return out, skip

def build_wavenet(input_shape, filters, filter_size, dilations, n_steps_out):
  inputs = tf.keras.layers.Input(shape=input_shape)
  res = inputs
  skip_connections = []
  for dilation in dilations:
    x, skip = wavenetblock(filters, filter_size, dilation)(res)
    skip_connections.append(skip)
    res = tf.keras.layers.Add()([x])
  x = tf.keras.layers.Add()(skip_connections)
  x = tf.keras.layers.Activation('relu')(x)
  x = tf.keras.layers.Conv1D(filters, 1, padding='same')(x)
  x = tf.keras.layers.Activation('relu')(x)
  x = tf.keras.layers.Conv1D(filters, 1, padding='same')(x)
  outputs = tf.keras.layers.Conv1D(6, 1, padding='same')(x)
  model = tf.keras.models.Model(inputs=inputs, outputs=outputs)

  return model


input_shape = (12, 1)
filters = 128
n_steps_out = 6
filter_size = 2
dilations = [1, 2, 4, 8, 16, 32]

wavenet_model = build_wavenet(input_shape, filters, filter_size, dilations, n_steps_out)
wavenet_model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 12, 1)]      0           []                               
                                                                                                  
 wavenetblock (wavenetblock)    ((None, 12, 128),    66560       ['input_1[0][0]']                
                                 (None, 12, 128))                                                 
                                                                                                  
 add_1 (Add)                    (None, 12, 128)      0           ['wavenetblock[0][0]']           
                                                                                                  
 wavenetblock_1 (wavenetblock)  ((None, 12, 128),    131584      ['add_1[0][0]']              