<a href="https://colab.research.google.com/github/Raiden-Makoto/swiftualizer/blob/main/TransformerModel/SwiftNETV1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# SwiftNET:
**Keras + TensorFlow implementation of WaveNet, trained specifically to generate pop songs in the style of Taylor Swift**

In [12]:
import os
os.environ["KERAS_BACKEND"] = "tensorflow"

import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow.data as tf_data
import keras
from keras import layers
from keras import ops
from keras.initializers import GlorotUniform, Zeros

### Helper Operation for the main **SwiftNET** Model
Code is sourced from [here](https://github.com/kokeshing/WaveNet-tf2/blob/master/model/module.py) but modified to use `Keras 3`.

A causal convolution layer is a type of 1D convolution designed for sequential data, where the output at time step $t$ depends only on inputs from time steps $\leq t$.
This structure ensures that no future information "leaks" into the past, which makes it well suited to autoregressive models like WaveNet and to time-series forecasting.

In [13]:
class CausalConvolutionLayer(layers.Conv1D):
    """1-D causal convolution with a one-step-at-a-time path for synthesis.

    With ``is_synthesis=False`` the layer simply delegates to
    ``keras.layers.Conv1D.call``.  With ``is_synthesis=True`` it consumes the
    last time step of ``inputs`` on every call: the frame is pushed into a
    FIFO queue covering the receptive field, the ``kernel_size`` dilated taps
    are read back from the queue and the convolution is evaluated as a single
    matrix product against the flattened kernel.

    The synthesis path is hard-wired to a batch size of 1 (the queue has a
    leading dimension of 1 and the taps are reshaped to ``[1, -1]``).

    Args:
        filters: Number of output channels.
        kernel_size: Width of the convolution kernel (int or 1-tuple).
        strides: Convolution stride, forwarded to ``Conv1D``.
        padding: Padding mode, ``'causal'`` by default.
        dilation_rate: Spacing between kernel taps (int or 1-tuple).
        residual_channels: Number of input channels, used to size the
            synthesis queue.  When ``None`` the size is taken from the last
            axis of the input shape at build time.
        *args: Accepted for signature compatibility only; ignored.
        **kwargs: Additional ``Conv1D`` keyword arguments such as ``name`` or
            ``use_bias``.
    """

    def __init__(
            self,
            filters,
            kernel_size,
            strides=1,
            padding='causal',
            dilation_rate=1,
            residual_channels=None,
            *args,
            **kwargs
        ):
        # Forward **kwargs so options like `name` / `use_bias` are honoured
        # instead of being silently discarded.
        super().__init__(
            filters,
            kernel_size,
            strides=strides,
            padding=padding,
            dilation_rate=dilation_rate,
            **kwargs
        )
        # Accept both plain ints and the 1-tuples Conv1D also allows.
        self.k = kernel_size[0] if isinstance(kernel_size, (list, tuple)) else kernel_size
        self.d = dilation_rate[0] if isinstance(dilation_rate, (list, tuple)) else dilation_rate
        # Number of past frames needed to cover all dilated taps.
        self.queue_len = self.k + (self.k - 1) * (self.d - 1)
        self.queue_dim = residual_channels
        self.queue = None
        self.init_queue()

    def build(self, input_shape):
        super().build(input_shape)
        if self.queue_dim is None:
            # Channel count was not given up front: take it from the input
            # (channels-last, matching the `inputs[:, -1, :]` access in call).
            self.queue_dim = input_shape[-1]
            self.init_queue()

    @property
    def linearized_weights(self):
        """Current kernel flattened to ``(kernel_size * in_channels, filters)``.

        Computed on access so it always reflects the live kernel, including
        after training or loading weights.
        """
        return ops.cast(
            ops.reshape(self.kernel, [-1, self.filters]),
            dtype="float32"
        )

    def call(self, inputs, is_synthesis=False):
        """Run the convolution.

        Args:
            inputs: Input tensor; indexed as ``(batch, time, channels)``.
            is_synthesis: When true, process only the last time step using
                the internal queue (batch size 1).

        Returns:
            The ``Conv1D`` output when ``is_synthesis`` is false, otherwise a
            tensor reshaped to ``(-1, 1, filters)``.
        """
        if not is_synthesis:
            return super().call(inputs)

        if self.k > 1:
            # Drop the oldest frame, append the newest one.
            newest = ops.expand_dims(inputs[:, -1, :], axis=1)
            self.queue = ops.concatenate([self.queue[:, 1:, :], newest], axis=1)
            # Every d-th frame lines up with one kernel tap.
            taps = self.queue[:, ::self.d, :]
        else:
            taps = inputs

        outputs = ops.matmul(ops.reshape(taps, [1, -1]), self.linearized_weights)
        if self.use_bias:
            outputs = ops.add(outputs, self.bias)
        return ops.reshape(outputs, [-1, 1, self.filters])

    def init_queue(self):
        """Reset the synthesis queue to zeros.

        Does nothing for ``kernel_size == 1`` layers (they keep no history)
        or while the channel count is still unknown.
        """
        if self.k > 1 and self.queue_dim is not None:
            self.queue = ops.zeros(
                [1, self.queue_len, self.queue_dim],
                dtype="float32"
            )

In [15]:
class ResidualConv1DGLU(keras.Model):
    """Gated residual block: conv1d + GLU => add condition => residual add + skip connection.

    The input goes through a dilated causal convolution producing
    ``gate_channels`` features, which are split in half along the channel
    axis into a tanh branch and a sigmoid branch.  The conditioning tensor is
    projected by a 1x1 convolution, split the same way and added to both
    branches before the ``tanh * sigmoid`` gate.  Two further 1x1 convolutions
    produce the skip output and the residual output (added back to the input).

    Args:
        residual_channels: Channels of the block input / residual output.
        gate_channels: Channels produced before gating; split into two equal
            halves, so it must be even.
        kernel_size: Kernel width of the dilated convolution.
        skip_out_channels: Channels of the skip output; defaults to
            ``residual_channels``.
        dilation_rate: Dilation of the dilated convolution.
        **kwargs: Forwarded to ``keras.Model`` (e.g. ``name``).
    """
    def __init__(
            self,
            residual_channels,
            gate_channels,
            kernel_size,
            skip_out_channels=None,
            dilation_rate=1,
            **kwargs
        ):
        super().__init__(**kwargs)
        self.residual_channels = residual_channels
        if skip_out_channels is None:
            skip_out_channels = residual_channels

        self.dilated_conv = CausalConvolutionLayer(
            gate_channels,
            kernel_size=kernel_size,
            padding='causal',
            dilation_rate=dilation_rate,
            residual_channels=residual_channels
        )
        self.conv_c = CausalConvolutionLayer(
            gate_channels,
            kernel_size=1,
            padding='causal'
        )
        self.conv_skip = CausalConvolutionLayer(
            skip_out_channels,
            kernel_size=1,
            padding='causal'
        )
        self.conv_out = CausalConvolutionLayer(
            residual_channels,
            kernel_size=1,
            padding='causal'
        )

    def _gated_activation(self, x, c):
        """Split ``x`` and ``c`` in half on axis 2, sum the halves and apply tanh * sigmoid."""
        x_tanh, x_sigmoid = ops.split(x, 2, axis=2)
        c_tanh, c_sigmoid = ops.split(c, 2, axis=2)
        return ops.tanh(x_tanh + c_tanh) * ops.sigmoid(x_sigmoid + c_sigmoid)

    @tf.function
    def call(self, inputs, c):
        """Full-sequence forward pass.

        Args:
            inputs: Block input; added back to the residual output, so its
                channel count must match ``residual_channels``.
            c: Conditioning tensor, projected to ``gate_channels`` by a 1x1
                convolution.

        Returns:
            Tuple ``(x, s)`` of residual output and skip output.
        """
        x = self._gated_activation(self.dilated_conv(inputs), self.conv_c(c))

        s = self.conv_skip(x)
        x = self.conv_out(x) + inputs

        return x, s

    def init_queue(self):
        """Reset the dilated convolution's synthesis queue."""
        # Layers with kernel_size == 1 keep no queue, so there is nothing to reset.
        if self.dilated_conv.k > 1:
            self.dilated_conv.init_queue()

    def synthesis_feed(self, inputs, c):
        """Single-step forward pass used during autoregressive generation.

        Same computation as ``call`` but every convolution runs with
        ``is_synthesis=True``, so the dilated convolution reads its history
        from its internal queue.

        Returns:
            Tuple ``(x, s)`` of residual output and skip output.
        """
        x = self._gated_activation(
            self.dilated_conv(inputs, is_synthesis=True),
            self.conv_c(c, is_synthesis=True)
        )
        s = self.conv_skip(x, is_synthesis=True)
        x = self.conv_out(x, is_synthesis=True) + inputs
        return x, s

## Create UpSampling Layers

In WaveNet, upsampling layers are applied at the final output layer to expand or refine the predictions (e.g., generating high-frequency audio details).

In [17]:
class UpsampleConv(keras.Model):
    """Nearest-neighbour upsampling followed by a constant-initialised smoothing convolution.

    The input is repeated ``rate`` times along the second spatial axis by
    ``UpSampling2D(size=(1, rate))`` and then passed through a bias-free
    ``Conv2D`` with one filter and a ``(1, 2 * rate + 1)`` kernel whose
    weights all start at ``1 / (2 * rate + 1)``.

    Args:
        rate: Integer upsampling factor along the second spatial axis.
        **kwargs: Forwarded to ``keras.Model`` (e.g. ``name``).
    """
    def __init__(self, rate, **kwargs):
        super().__init__(**kwargs)
        smoothing_width = rate * 2 + 1
        self.upsampling = layers.UpSampling2D(
            size=(1, rate),
            interpolation='nearest'
        )
        self.conv = layers.Conv2D(
            filters=1,
            kernel_size=(1, smoothing_width),
            padding='same',
            use_bias=False,
            # Uniform weights summing to 1 across the kernel width.
            kernel_initializer=keras.initializers.Constant(1. / smoothing_width)
        )

    @tf.function
    def call(self, x):
        """Upsample ``x`` and smooth the result.

        ``x`` must be a 4-D tensor as required by ``UpSampling2D`` / ``Conv2D``
        (presumably ``(batch, features, time, 1)`` -- confirm against the caller).
        """
        return self.conv(self.upsampling(x))

In [18]:
class UpsampleNetwork(keras.Model):
    """Chain of ``UpsampleConv`` stages applied one after another.

    Args:
        upsample_scales: Iterable of integer rates; one ``UpsampleConv`` is
            created per entry, in order.
        **kwargs: Forwarded to ``keras.Model`` (e.g. ``name``).
    """
    def __init__(self, upsample_scales, **kwargs):
        super().__init__(**kwargs)
        self.upsample_layers = [UpsampleConv(scale) for scale in upsample_scales]

    @tf.function
    def call(self, feat):
        """Pass ``feat`` through every upsampling stage and return the result."""
        for layer in self.upsample_layers:
            feat = layer(feat)
        return feat

## Create the SwiftNET Model