<a href="https://colab.research.google.com/github/Raiden-Makoto/swiftualizer/blob/main/TransformerModel/SwiftNETV1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# SwiftNET:
**Keras + TensorFlow implementation of WaveNet, trained specifically to generate pop songs in the style of Taylor Swift**

In [2]:
import os
os.environ["KERAS_BACKEND"] = "tensorflow"

import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow.data as tf_data
import keras
from keras import layers

### Helper Operations for the main **SwiftNET** Model

A causal convolution layer is a 1D convolution for sequential data in which the output at time step $t$ depends only on inputs from time steps $\leq t$. This guarantees that no future information "leaks" into the past, which makes it suitable for autoregressive models such as WaveNet and for time-series forecasting.

In [3]:
def CausalConvolutionLayer(value, filters, kernel_size, dilation, name='causal_conv'):
    '''Apply a freshly constructed causal, dilated `Conv1D` layer to `value`.

    A new `layers.Conv1D` instance (and therefore a new set of weights) is
    built on every call, inside a `tf.name_scope(name)` block.

    Args:
        value: Input passed straight to the `Conv1D` layer.
        filters: Forwarded as `Conv1D(filters=...)`.
        kernel_size: Forwarded as `Conv1D(kernel_size=...)`.
        dilation: Forwarded as `Conv1D(dilation_rate=...)`.
        name: Name of the enclosing `tf.name_scope`; it is not passed to the
            layer itself.

    Returns:
        The output of the convolution layer applied to `value`.
    '''
    conv_options = {
        'filters': filters,
        'kernel_size': kernel_size,
        'dilation_rate': dilation,
        # 'causal' padding pads only on the left so step t never sees t+1.
        'padding': 'causal',
    }
    with tf.name_scope(name):
        return layers.Conv1D(**conv_options)(value)

Companding (short for compressing + expanding) is a technique used in signal processing to reduce the dynamic range of a signal before quantization and restore it afterward. It helps to improve signal quality, reduce quantization noise, and optimize storage or transmission efficiency—especially in audio, speech processing, and telecommunications.

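The mu-law transformation is a nonlinear companding algorithm used in digital audio processing and speech compression. It compresses the dynamic range of an audio signal via $f(x) = \operatorname{sign}(x)\,\dfrac{\ln(1 + \mu |x|)}{\ln(1 + \mu)}$, so that quiet samples are quantized with finer steps than loud ones. In the encoder below, $\mu$ is the number of quantization channels minus one.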

In [4]:
def MuLawEncode(audio, num_quantization_channels):
    '''Quantize `audio` into integer bins using mu-law companding.

    Args:
        audio: Tensor of samples. Magnitudes above 1.0 are clipped to 1.0
            before companding; the sign is taken from the unclipped input.
        num_quantization_channels: Number of output bins; mu is this value
            minus one.

    Returns:
        An int32 tensor computed as `(signal + 1) / 2 * mu + 0.5` cast to
        integer, where `signal` is the companded value.
    '''
    with tf.name_scope("MuLawEncode"):
        mu = tf.cast(num_quantization_channels - 1, dtype=tf.float32)
        # Clip the magnitude so the log term stays within [0, log1p(mu)].
        clipped = tf.minimum(tf.abs(audio), 1.0)
        companded = tf.sign(audio) * (
            tf.math.log1p(mu * clipped) / tf.math.log1p(mu)
        )
        # Shift from [-1, 1] to [0, mu]; the +0.5 rounds to nearest on cast.
        return tf.cast((companded + 1) / 2 * mu + 0.5, dtype=tf.int32)

### Create `tf.Variable`s

In [5]:
def create_variable(name, shape):
    '''Build a trainable `tf.Variable` with Glorot (Xavier) uniform values.

    Args:
        name: Used both for the enclosing `tf.name_scope` and as the
            variable's name.
        shape: Shape handed to the Glorot-uniform initializer.

    Returns:
        The newly created trainable `tf.Variable`.
    '''
    with tf.name_scope(name):
        initial_value = tf.keras.initializers.GlorotUniform()(shape)
        return tf.Variable(initial_value, name=name, trainable=True)

In [6]:
def create_embedding_table(name, shape):
    '''Build a trainable embedding table of the given 2-D shape.

    When `shape[0] == shape[1]` the table starts as a float32 identity
    matrix; otherwise it is delegated to `create_variable` (Glorot uniform).

    Args:
        name: Used for the `tf.name_scope` and as the variable's name.
        shape: Indexable with at least two entries (rows, columns).

    Returns:
        The newly created trainable `tf.Variable`.
    '''
    with tf.name_scope(name):
        rows, cols = shape[0], shape[1]
        if rows != cols:
            # Non-square tables cannot be an identity; use random init.
            return create_variable(name, shape)
        identity = np.identity(n=rows, dtype=np.float32)
        return tf.Variable(identity, name=name, trainable=True)

In [7]:
def create_bias_variable(name, shape):
    '''Create a trainable float32 bias variable initialized to zero.

    Args:
        name: Name given to the resulting `tf.Variable`.
        shape: Shape of the bias variable.

    Returns:
        A trainable `tf.Variable` of zeros with the requested shape.
    '''
    # The TF2 initializer takes the dtype at call time, not at construction.
    initializer = tf.constant_initializer(value=0.0)
    initial_value = initializer(shape=shape, dtype=tf.float32)
    # `name` must be passed by keyword: the second positional parameter of
    # tf.Variable is `trainable`, so a positional name would be silently
    # treated as a truthy `trainable` flag and the variable left unnamed.
    return tf.Variable(initial_value, name=name, trainable=True)

## Create the SwiftNET Model

In [10]:
class SwiftNet(keras.Model):
    def __init__(
        self,
        batch_size,
        dilations,
        filter_width,
        residual_channels,
        dilation_channels,
        skip_channels,
        quantization_channels=2**8,
        use_biases=False,
        scalar_input=False,
        initial_filter_width=32,
        histograms=False,
        global_condition_channels=None,
        global_condition_cardinality=None
    ):
        '''Store the hyperparameters, derive the receptive field and create variables.

        Every argument is saved unchanged as an attribute of the same name.

        Args:
            batch_size: Stored as-is (presumably samples per batch - confirm
                against the code that consumes it).
            dilations: Iterable of numbers; it is summed when computing the
                receptive field.
            filter_width: Convolution filter width used in the receptive
                field formula.
            residual_channels: Stored as-is; not used in this method.
            dilation_channels: Stored as-is; not used in this method.
            skip_channels: Stored as-is; not used in this method.
            quantization_channels: Stored as-is; defaults to 2**8 (256).
            use_biases: Stored as-is; defaults to False.
            scalar_input: When truthy, `initial_filter_width` rather than
                `filter_width` contributes the extra term to the receptive
                field.
            initial_filter_width: Only affects the receptive field when
                `scalar_input` is truthy; defaults to 32.
            histograms: Stored as-is; defaults to False.
            global_condition_channels: Stored as-is; defaults to None.
            global_condition_cardinality: Stored as-is; defaults to None.
        '''
        super(SwiftNet, self).__init__()
        self.batch_size = batch_size
        self.dilations = dilations
        self.filter_width = filter_width
        self.residual_channels = residual_channels
        self.dilation_channels = dilation_channels
        self.quantization_channels = quantization_channels
        self.use_biases = use_biases
        self.skip_channels = skip_channels
        self.scalar_input = scalar_input
        self.initial_filter_width = initial_filter_width
        self.histograms = histograms
        self.global_condition_channels = global_condition_channels
        self.global_condition_cardinality = global_condition_cardinality

        # Derived once from the hyperparameters stored above.
        self.receptive_field = self.calculate_receptive_field(
            self.filter_width,
            self.dilations,
            self.scalar_input,
            self.initial_filter_width
        )
        # NOTE(review): keras.Model appears to expose `variables` as a
        # read-only property, so this assignment likely raises AttributeError
        # at construction time - confirm, and consider storing the result of
        # _create_variables() under a different attribute name.
        self.variables = self._create_variables()

    def calculate_receptive_field(self, filter_width, dilations, scalar_input, initial_filter_width):
        '''Compute the receptive field size from the filter and dilation settings.

        The result is `(filter_width - 1) * sum(dilations) + 1` plus the
        contribution of the first layer: `initial_filter_width - 1` when
        `scalar_input` is truthy, otherwise `filter_width - 1`.

        Args:
            filter_width: Width of the dilated convolution filters.
            dilations: Iterable of dilation factors; they are summed.
            scalar_input: Selects which width the first layer contributes.
            initial_filter_width: First-layer width used for scalar input.

        Returns:
            The receptive field size as a number.
        '''
        first_layer_width = initial_filter_width if scalar_input else filter_width
        dilated_span = (filter_width - 1) * sum(dilations) + 1
        return dilated_span + (first_layer_width - 1)

    def _create_variables(self):
        variables = {}