<a href="https://colab.research.google.com/github/Raiden-Makoto/swiftualizer/blob/main/TransformerModel/SwiftNETV1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# SwiftNET:
**Keras + TensorFlow implementation of WaveNet, trained specifically to generate pop songs in the style of Taylor Swift**

In [None]:
import os
os.environ["KERAS_BACKEND"] = "tensorflow"

import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow.data as tf_data
import keras
from keras import layers
from keras import ops
from keras.initializers import GlorotUniform, Zeros

### Helper Operation for the main **SwiftNET** Model
Code is sourced from [here](https://github.com/kokeshing/WaveNet-tf2/blob/master/model/module.py) but modified to use `Keras 3`.

A causal convolution layer is a type of 1D convolution designed for sequential data where the output at time step $t$ depends only on inputs from time steps
$ \leq t$. This structure ensures that there is no "leakage" of future information into the past, making it useful for autoregressive models like WaveNet and time-series forecasting.

In [None]:
class CausalConvolutionLayer(layers.Conv1D):
    """Causal 1D convolution with a one-step-at-a-time synthesis path.

    In the default mode the layer behaves exactly like ``layers.Conv1D``.
    With ``is_synthesis=True`` it keeps a queue of recent input frames and
    computes the output for the newest time step with a single matrix
    multiplication against the flattened kernel.

    Args:
        filters: Number of output channels.
        kernel_size: Integer width of the convolution kernel.
        strides: Convolution stride, forwarded to ``Conv1D``.
        padding: Padding mode, forwarded to ``Conv1D`` (default ``'causal'``).
        dilation_rate: Integer dilation rate.
        residual_channels: Channel count of the frames stored in the
            synthesis queue. When ``None`` and ``kernel_size > 1`` it is
            taken from the last axis of the input shape at build time
            (assumes channels-last inputs).
        *args: Accepted for backward compatibility and ignored; there is no
            unambiguous ``Conv1D`` parameter to map them to.
        **kwargs: Extra ``Conv1D`` keyword options (``name``, ``use_bias``,
            ``activation``, ...).
    """

    def __init__(
            self,
            filters,
            kernel_size,
            strides=1,
            padding='causal',
            dilation_rate=1,
            residual_channels=None,
            *args,
            **kwargs
        ):
        super().__init__(
            filters,
            kernel_size,
            strides=strides,
            padding=padding,
            dilation_rate=dilation_rate,
            **kwargs
        )
        self.k = kernel_size
        self.d = dilation_rate
        self.queue = None
        if kernel_size > 1:
            # Receptive field of the dilated kernel, in time steps.
            self.queue_len = kernel_size + (kernel_size - 1) * (dilation_rate - 1)
            self.queue_dim = residual_channels
            self.init_queue()

    @property
    def linearized_weights(self):
        """Kernel flattened to ``[kernel_size * in_channels, filters]``.

        Computed on every access so that it always reflects the current
        kernel values (e.g. after training or after loading weights) instead
        of a snapshot taken when the layer was built.
        """
        return ops.cast(
            ops.reshape(self.kernel, [-1, self.filters]),
            dtype="float32"
        )

    def build(self, input_shape):
        super().build(input_shape)
        if self.k > 1 and self.queue_dim is None:
            # No explicit channel count was given: infer it from the input.
            self.queue_dim = input_shape[-1]
            self.init_queue()

    def call(self, inputs, is_synthesis=False):
        """Apply the convolution.

        Args:
            inputs: Input tensor. In synthesis mode only the last time step
                is consumed and the batch size is expected to be 1.
            is_synthesis: When True, use the queue-based incremental path.

        Returns:
            The regular ``Conv1D`` output, or in synthesis mode a tensor of
            shape ``[-1, 1, filters]`` for the newest time step.
        """
        if not is_synthesis:
            return super().call(inputs)

        if self.k > 1:
            if self.queue is None:
                self.init_queue()
            # Drop the oldest frame and append the newest one.
            newest = ops.expand_dims(inputs[:, -1, :], axis=1)
            self.queue = ops.concatenate([self.queue[:, 1:, :], newest], axis=1)
            # Every d-th frame corresponds to one tap of the dilated kernel.
            inputs = self.queue[:, 0::self.d, :] if self.d > 1 else self.queue

        outputs = ops.matmul(ops.reshape(inputs, [1, -1]), self.linearized_weights)
        if self.use_bias:
            # Bias has shape [filters] and broadcasts over the single row.
            outputs = outputs + self.bias
        outputs = ops.reshape(outputs, [-1, 1, self.filters])
        if self.activation is not None:
            # Keep the synthesis path consistent with the regular Conv1D path.
            outputs = self.activation(outputs)
        return outputs

    def init_queue(self):
        """Reset the synthesis queue to zeros.

        Does nothing for ``kernel_size == 1`` (no history is needed) or while
        the channel count is still unknown (the queue is then created once
        the layer is built).
        """
        if self.k <= 1 or self.queue_dim is None:
            return
        self.queue = ops.zeros([1, self.queue_len, self.queue_dim], dtype="float32")

In [None]:
class ResidualConv1DGLU(keras.Model):
    """Dilated causal conv + gated activation with conditioning, residual and skip outputs.

    The block computes ``tanh(a) * sigmoid(b)`` where ``a`` and ``b`` are the
    two halves (along the channel axis) of the dilated convolution output
    plus the two halves of a 1x1 projection of the conditioning input. Two
    further 1x1 convolutions produce the skip output and the residual output;
    the block input is added to the latter.

    Args:
        residual_channels: Channels of the block input / residual output.
        gate_channels: Channels produced by the dilated convolution and the
            conditioning projection; split in two for the gate, so it must
            be even.
        kernel_size: Kernel width of the dilated convolution.
        skip_out_channels: Channels of the skip output; defaults to
            ``residual_channels``.
        dilation_rate: Dilation of the dilated convolution.
        **kwargs: Forwarded to ``keras.Model`` (e.g. ``name``).
    """

    def __init__(
            self,
            residual_channels,
            gate_channels,
            kernel_size,
            skip_out_channels=None,
            dilation_rate=1,
            **kwargs
        ):
        super().__init__(**kwargs)
        self.residual_channels = residual_channels
        if skip_out_channels is None:
            skip_out_channels = residual_channels

        self.dilated_conv = CausalConvolutionLayer(
            gate_channels,
            kernel_size=kernel_size,
            padding='causal',
            dilation_rate=dilation_rate,
            residual_channels=residual_channels
        )
        self.conv_c = CausalConvolutionLayer(
            gate_channels,
            kernel_size=1,
            padding='causal'
        )
        self.conv_skip = CausalConvolutionLayer(
            skip_out_channels,
            kernel_size=1,
            padding='causal'
        )
        self.conv_out = CausalConvolutionLayer(
            residual_channels,
            kernel_size=1,
            padding='causal'
        )

    def _gated_forward(self, inputs, c, is_synthesis):
        """Shared forward pass for the training and synthesis modes."""
        x = self.dilated_conv(inputs, is_synthesis=is_synthesis)
        x_tanh, x_sigmoid = ops.split(x, 2, axis=2)

        c = self.conv_c(c, is_synthesis=is_synthesis)
        c_tanh, c_sigmoid = ops.split(c, 2, axis=2)

        # Gated activation unit, conditioned on c.
        gated = ops.tanh(x_tanh + c_tanh) * ops.sigmoid(x_sigmoid + c_sigmoid)

        s = self.conv_skip(gated, is_synthesis=is_synthesis)
        x = self.conv_out(gated, is_synthesis=is_synthesis)

        # Residual connection around the whole block.
        x = x + inputs
        return x, s

    @tf.function
    def call(self, inputs, c):
        """Run the block on full sequences.

        Args:
            inputs: Block input; channels are on axis 2.
            c: Conditioning input; channels are on axis 2.

        Returns:
            Tuple ``(residual_output, skip_output)``.
        """
        return self._gated_forward(inputs, c, is_synthesis=False)

    def init_queue(self):
        """Reset the synthesis queue of the dilated convolution."""
        self.dilated_conv.init_queue()

    def synthesis_feed(self, inputs, c):
        """Run the block for one synthesis step using the queued history.

        Args:
            inputs: Input for the newest time step.
            c: Conditioning input for the newest time step.

        Returns:
            Tuple ``(residual_output, skip_output)`` for the newest step.
        """
        return self._gated_forward(inputs, c, is_synthesis=True)

## Create UpSampling Layers
In WaveNet, upsampling layers are applied to the conditioning features (here, the mel spectrogram) to expand them along the time axis so that they line up with the audio samples the model predicts.

In [None]:
class UpsampleConv(keras.Model):
    """Nearest-neighbour upsampling along the width axis followed by a smoothing conv.

    The input is repeated ``rate`` times along its third axis (size
    ``(1, rate)``) and then passed through a bias-free single-filter
    ``Conv2D`` of width ``2 * rate + 1`` whose kernel is initialised to the
    constant ``1 / (2 * rate + 1)``, i.e. a moving average at initialisation.

    Args:
        rate: Integer upsampling factor.
        **kwargs: Forwarded to ``keras.Model`` (e.g. ``name``).
    """

    def __init__(self, rate, **kwargs):
        # Subclass the same ``keras.Model`` the rest of the file uses so the
        # Keras layers below are never mixed with a different Keras version.
        super().__init__(**kwargs)
        smoothing_width = rate * 2 + 1
        self.upsampling = layers.UpSampling2D(
            size=(1, rate),
            interpolation='nearest'
        )
        self.conv = layers.Conv2D(
            filters=1,
            kernel_size=(1, smoothing_width),
            padding='same',
            use_bias=False,
            kernel_initializer=keras.initializers.Constant(1. / smoothing_width)
        )

    @tf.function
    def call(self, x):
        """Upsample ``x`` and smooth the result."""
        return self.conv(self.upsampling(x))

In [None]:
class UpsampleNetwork(keras.Model):
    """Chain of ``UpsampleConv`` stages applied one after another.

    The total upsampling factor is the product of ``upsample_scales``.

    Args:
        upsample_scales: Iterable of integer factors, one per stage.
        **kwargs: Forwarded to ``keras.Model`` (e.g. ``name``).
    """

    def __init__(self, upsample_scales, **kwargs):
        # Use the same ``keras.Model`` base as the rest of the file.
        super().__init__(**kwargs)
        self.upsample_layers = [UpsampleConv(scale) for scale in upsample_scales]

    @tf.function
    def call(self, feat):
        """Pass ``feat`` through every upsampling stage in order."""
        for layer in self.upsample_layers:
            feat = layer(feat)
        return feat

## Mu Law Quantization
Companding (short for compressing + expanding) is a technique used in signal processing to reduce the dynamic range of a signal before quantization and restore it afterward.

It helps to improve signal quality, reduce quantization noise, and optimize storage or transmission efficiency—especially in audio, speech processing, and telecommunications.

The mu-law transformation is a nonlinear companding algorithm used in digital audio processing and speech compression. It reduces the dynamic range of an audio signal, improving quantization at lower amplitudes while preserving detail in louder signals.

In [None]:
def MuLawQuantize(x, quantization_channels=255):
    """Mu-law compand a signal and quantize it to integer-valued bins.

    The signal is first companded with
    ``sign(x) * log1p(mu * |x|) / log1p(mu)`` and the result is then mapped
    to the bin range ``[0, mu]`` and rounded to the nearest bin. This is the
    inverse of the ``2 * q / mu - 1`` mapping applied by
    ``InverseMuLawQuantize``.

    Args:
        x: Signal tensor; assumed to lie in ``[-1, 1]`` (not checked).
        quantization_channels: The mu parameter; the largest bin index.

    Returns:
        Floating-point tensor of whole numbers in ``[0, quantization_channels]``
        for inputs in ``[-1, 1]``. Cast to an integer dtype if class indices
        are needed.
    """
    mu = quantization_channels
    companded = ops.sign(x) * ops.log1p(mu * ops.abs(x)) / ops.log1p(mu)
    # Shift [-1, 1] to [0, mu]; "+ 0.5" then floor rounds to the nearest bin.
    return ops.floor((companded + 1.0) / 2.0 * mu + 0.5)

In [15]:
def InverseMuLawQuantize(x, quantization_channels=255):
    """Map quantized mu-law bins back to a signal.

    The input is cast to float32, rescaled with ``2 * x / mu - 1`` and then
    expanded with ``sign(y) * (1 / mu) * ((1 + mu) ** |y| - 1)``.

    Args:
        x: Tensor of bin values.
        quantization_channels: The mu parameter.

    Returns:
        The expanded signal tensor.
    """
    mu = quantization_channels
    companded = 2 * ops.cast(x, dtype=np.float32) / mu - 1
    magnitude = (1.0 + mu) ** ops.abs(companded) - 1.0
    return ops.sign(companded) * (1.0 / mu) * magnitude

## Create the SwiftNET Model

In [None]:
class SwiftNET(keras.Model):
    """WaveNet-style stack of gated residual blocks with upsampled conditioning.

    The model consists of an upsampling network for the conditioning
    features, a 1x1 input convolution, two cycles of ten residual blocks with
    dilation rates ``2**0 .. 2**9``, and a ReLU / 1x1-conv post-processing
    head that ends in 256 output channels.

    Args:
        num_mel: Number of mel bands of the conditioning features. Currently
            not used by the layers built here; kept for interface
            compatibility.
        upsample_scales: Per-stage upsampling factors for the conditioning
            features.
        **kwargs: Forwarded to ``keras.Model`` (e.g. ``name``).
    """

    def __init__(self, num_mel, upsample_scales, **kwargs):
        super().__init__(**kwargs)
        self.upsampler = UpsampleNetwork(upsample_scales)
        self.initial = CausalConvolutionLayer(
            filters = 128,
            kernel_size = 1,
            padding = 'causal'
        )
        self.residual_blocks = []
        for _ in range(2):
            for exponent in range(10):
                self.residual_blocks.append(
                    ResidualConv1DGLU(
                        residual_channels = 128,
                        gate_channels = 256,
                        kernel_size = 3,
                        skip_out_channels = 128,
                        dilation_rate = 2 ** exponent
                    )
                )
        self.postprocessing = [
            layers.ReLU(),
            layers.Conv1D(
                filters = 128,
                kernel_size = 1,
                padding = 'causal'
            ),
            layers.ReLU(),
            layers.Conv1D(
                filters = 256,
                kernel_size = 1,
                padding = 'causal'
            )
        ]

    def init_queue(self):
        """Reset the synthesis queues of all residual blocks."""
        for block in self.residual_blocks:
            block.init_queue()

    @tf.function
    def call(self, inputs, c=None):
        """Compute the output for a batch of sequences.

        Args:
            inputs: Either the signal tensor, or a ``(signal, conditioning)``
                pair when ``c`` is not given.
            c: Conditioning features. Presumably shaped
                ``(batch, num_mel, frames)`` — TODO confirm against the data
                pipeline. After upsampling they are transposed to
                ``(batch, time, channels)`` for the residual blocks.

        Returns:
            Output of the post-processing head (256 channels on the last
            axis).

        Raises:
            ValueError: If ``c`` is omitted and ``inputs`` is not a pair.
        """
        if c is None:
            if not isinstance(inputs, (list, tuple)) or len(inputs) != 2:
                raise ValueError(
                    "SwiftNET needs conditioning features: call model(x, c) "
                    "or model((x, c))."
                )
            inputs, c = inputs

        # Treat the conditioning as a one-channel image so the 2D upsampler
        # can stretch its last (time) axis, then go back to channels-last 1D.
        c = ops.expand_dims(c, axis=-1)
        c = self.upsampler(c)
        c = ops.transpose(ops.squeeze(c, axis=-1), axes=(0, 2, 1))

        x = self.initial(inputs)
        skips = None
        for block in self.residual_blocks:
            x, h = block(x, c)
            skips = h if skips is None else skips + h

        # Only the summed skip connections feed the output head.
        x = skips
        for layer in self.postprocessing:
            x = layer(x)
        return x

### Parameters

In [None]:
# Per-stage upsampling factors. Their product (3 * 7 * 21 = 441) equals
# HOP_SIZE below. Presumably passed to SwiftNET as `upsample_scales` —
# verify against the training code.
UPSAMPLE_SCALES = [3, 7, 21]

# Audio / spectrogram settings. SR, NUM_MELS, N_FFT, HOP_SIZE and WIN_SIZE
# are the default arguments of the audio helpers defined later in this file
# (load_wav, trim_silence, save_wav, MelSpectrogram).
# NOTE(review): SEQ_LEN equals SR, i.e. one second if it is counted in
# samples — confirm where it is consumed.
SEQ_LEN = 44100
SR = 44100
NUM_MELS = 128
N_FFT = 2048
HOP_SIZE = 441
WIN_SIZE = 2048

# Training hyper-parameters. NOTE(review): not referenced in the visible part
# of the file; the names suggest learning rate, LR decay schedule, epoch
# count and batch size — confirm against the training loop.
LR = 1e-4
DECAY_RATE = 0.8
DECAY_STEPS = int(4e5)
EPOCHS = 1000
BATCH_SIZE = 9

# NOTE(review): presumably the number of test samples to generate and the
# checkpoint/save interval — confirm against the training loop.
N_TEST_SAMPLE = 5
SAVE_IVAL = 50

## Loading Dataset: Utilities
`librosa` is a Python library for audio analysis; we use it to load and process the `.wav` files we will be working with.

In [None]:
import librosa
from scipy.io import wavfile

In [None]:
def files_to_list(filename):
    """Read a UTF-8 text file and return its lines with trailing whitespace removed.

    Args:
        filename: Path of the text file to read.

    Returns:
        A list with one string per line, in file order. Blank lines are kept
        as empty strings.
    """
    with open(filename, encoding="utf-8") as handle:
        return [line.rstrip() for line in handle]

In [17]:
def load_wav(path, sampling_rate: int=SR):
    """Load an audio file as a mono waveform.

    Args:
        path: Location of the audio file.
        sampling_rate: Rate passed to ``librosa.load`` as ``sr``; defaults to
            the module-level ``SR``.

    Returns:
        The waveform returned by ``librosa.load`` (the sample rate it also
        returns is discarded).
    """
    samples, _ = librosa.load(path, sr=sampling_rate, mono=True)
    return samples

In [18]:
def trim_silence(
        wav,
        top_db: float=40.0,
        fft_size: int=2048,
        hop_size: int=HOP_SIZE
    ):
    """
    Remove leading and trailing silence from a waveform via ``librosa.effects.trim``.

    Args:
        wav (ndarray): Audio waveform to trim.
        top_db (float): Threshold in decibels below the reference that counts as silence.
        fft_size (int): Analysis frame length, passed as ``frame_length``.
        hop_size (int): Step between analysis frames, passed as ``hop_length``.

    Returns:
        ndarray: The trimmed waveform (the index interval librosa also
        returns is discarded).
    """
    trim_result = librosa.effects.trim(
        wav,
        top_db=top_db,
        frame_length=fft_size,
        hop_length=hop_size
    )
    return trim_result[0]

In [None]:
def normalize(wav):
    """Normalize a waveform with ``librosa.util.normalize`` using its default settings."""
    return librosa.util.normalize(wav)

In [19]:
def save_wav(wav, path: str, sr: int=SR):
    """Peak-normalize a waveform to the int16 range and write it as a .wav file.

    The caller's array is left untouched: scaling is done on a new array
    rather than in place, which also makes integer-typed input work.

    Args:
        wav: Waveform samples (array-like).
        path: Destination file path.
        sr: Sample rate written to the file; defaults to the module-level
            ``SR``.
    """
    samples = np.asarray(wav)
    # An empty signal has no peak; fall through to the 0.0001 floor below.
    peak = np.max(np.abs(samples)) if samples.size else 0.0
    # The floor avoids dividing by zero for an all-silent signal.
    scaled = samples * (32767 / max(0.0001, peak))
    wavfile.write(path, sr, scaled.astype(np.int16))

In [20]:
def MelSpectrogram(
    wav,
    sampling_rate: int = SR,
    num_mels: int = NUM_MELS,
    n_fft: int = N_FFT,
    hop_size: int = HOP_SIZE,
    win_size: int = WIN_SIZE
):
    """Compute a log10-scaled mel magnitude spectrogram of a waveform.

    Args:
        wav: Input waveform.
        sampling_rate: Sample rate used to build the mel filter bank.
        num_mels: Number of mel bands.
        n_fft: FFT size for the STFT and the filter bank.
        hop_size: STFT hop length.
        win_size: STFT window length.

    Returns:
        Array of ``log10`` mel magnitudes with ``num_mels`` rows; values are
        floored at ``1e-5`` before the logarithm, so the minimum is ``-5``.
    """
    d = librosa.stft(y=wav, n_fft=n_fft, hop_length=hop_size, win_length=win_size, pad_mode='constant')
    # Keyword arguments: librosa >= 0.10 no longer accepts sr/n_fft
    # positionally, and older releases accept the keyword form too.
    mel_filter = librosa.filters.mel(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels)
    s = np.dot(mel_filter, np.abs(d))
    # Floor before the log to avoid -inf on silent bins.
    return np.log10(np.maximum(s, 1e-5))