In [1]:
import keras
import tensorflow as tf
import scipy.io.wavfile as wavfile
import numpy as np
from tqdm import tqdm
import re
import os

First, let's design the VQ VAE layers, since this is not a part of Tensorflow.

In [2]:

class VectorQuantizer(keras.layers.Layer):
    def __init__(self, num_embeddings, embedding_dim, beta=0.25, **kwargs):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings

        # The `beta` parameter is best kept between [0.25, 2] as per the paper.
        self.beta = beta

        # Initialize the embeddings which we will quantize.
        w_init = tf.random_uniform_initializer()
        self.embeddings = tf.Variable(
            initial_value=w_init(
                shape=(self.embedding_dim, self.num_embeddings), dtype="float32"
            ),
            trainable=True,
            name="embeddings_vqvae",
        )

    def call(self, x):
        # Calculate the input shape of the inputs and
        # then flatten the inputs keeping the embedding_dim intact.
        input_shape = tf.shape(x)
        flattened = tf.reshape(x, [-1, self.embedding_dim])

        # Quantization
        encoding_indices = self.get_code_indices(flattened)
        encodings = tf.one_hot(encoding_indices, self.num_embeddings)
        quantized = tf.matmul(encodings, self.embeddings, transpose_b=True)

        # Reshape the quantized values back to the original input shape
        quantized = tf.reshape(quantized, input_shape)

        # Calculate vector quantization loss and add that to the layer
        commitment_loss = tf.reduce_mean((tf.stop_gradient(quantized) - x) ** 2)
        codebook_loss = tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2)
        self.add_loss(self.beta * commitment_loss + codebook_loss)

        # Straight-through estimator.
        quantized = x + tf.stop_gradient(quantized - x)
        return quantized

    def get_code_indices(self, flattened_inputs):
        # Calculate L2-normalized distance between the inputs and the codes.
        similarity = tf.matmul(flattened_inputs, self.embeddings)
        distances = (
            tf.reduce_sum(flattened_inputs**2, axis=1, keepdims=True)
            + tf.reduce_sum(self.embeddings**2, axis=0)
            - 2 * similarity
        )

        # Derive the indices for minimum distances.
        encoding_indices = tf.argmin(distances, axis=1)
        return encoding_indices


Now that the VQ VAE is defined, we need to define our encoder and decoder. This can be any model for encoding and decoding. OpenAI's JukeBox project uses noncausal 1-D dilated convolutions, interleaved with downsampling and upsampling 1-D convolutions. So let's use this approach!

In [3]:
def get_encoder(input_shape, conv_filters, latent_dim):
    encoder_inputs = keras.Input(shape=input_shape)
    
    x = keras.layers.Conv1D(conv_filters, 3, activation="relu", strides=2, padding="same")(
        encoder_inputs
    )
    
    x = keras.layers.Conv1D(2 * conv_filters, 3, activation="relu", strides=2, padding="same")(x)
    
    encoder_outputs = keras.layers.Conv1D(latent_dim, 1, padding="same")(x)
    return keras.Model(encoder_inputs, encoder_outputs, name="encoder")


def get_decoder(input_shape, conv_filters, latent_dim):
    latent_inputs = keras.Input(shape=get_encoder(input_shape, conv_filters, latent_dim).output.shape[1:])
    
    x = keras.layers.Conv1DTranspose(
        2 * conv_filters, 3, activation="relu", strides=2, padding="same"
    )(latent_inputs)
    
    x = keras.layers.Conv1DTranspose(
        conv_filters, 3, activation="relu", strides=2, padding="same"
    )(x)
    
    decoder_outputs = keras.layers.Conv1DTranspose(1, 3, padding="same")(x)
    return keras.Model(latent_inputs, decoder_outputs, name="decoder")

def get_vqvae(input_shape, num_embeddings, conv_filters, latent_dim):
    vq_layer = VectorQuantizer(num_embeddings, latent_dim, name="vector_quantizer")
    inputs = keras.Input(shape=input_shape)
    
    encoder = get_encoder(input_shape, conv_filters, latent_dim)
    decoder = get_decoder(input_shape, conv_filters, latent_dim)
    
    encoder_outputs = encoder(inputs)
    quantized_latents = vq_layer(encoder_outputs)
    reconstructions = decoder(quantized_latents)
    return keras.Model(inputs, reconstructions, name="vq_vae")

Let's load from raw wav data and test our network's ability to train.

In [4]:


embedding_size = 2 ** 16 - 1

def load_song_file(song_file: str):
    rate, song_data = wavfile.read(song_file)

    mono_data = song_data

    if song_data.shape[1] == 2:
        mono_data = np.average(song_data, axis=1)
        
    audio = (mono_data / np.max(mono_data))
    audio = (audio * embedding_size) - 2 ** 15
    audio = audio.astype(np.int16)
    return rate, audio
        
        
def get_training_sequences(data, rate, chunk_duration):
    chunk_size = int(rate * chunk_duration)
    
    Xs = []
    for i in range(0, len(data), chunk_size):
        chunk = data[i:i + chunk_size]
        chunk = np.pad(chunk, (0, chunk_size - len(chunk)), mode='constant')
        chunk = tf.one_hot(chunk, depth=embedding_size)
        Xs.append(chunk)
        # Ys.append(chunk[-1])
        
        
    X = np.array(Xs)
    # Y = np.array(Ys)
    
    
    return X

rate = 44_100
sample_song = "data/Wavs/tvari-tokyo-cafe-159065.wav"
_, data = load_song_file(sample_song)

test_duration = 10
sequence_duration = 1
set_samples = int(sequence_duration * rate)
data = data[:set_samples * test_duration]


X = get_training_sequences(data, rate, sequence_duration)

print(X.shape)
print(set_samples)

        

(10, 44100, 65535)
44100


In [5]:
conv_filters = 32
latent_dim = 32
n_batch = 32

vaqae = get_vqvae((set_samples, embedding_size), embedding_size, conv_filters, latent_dim)

optimizer = keras.optimizers.Adam(learning_rate=1e-3)
loss = keras.losses.MeanSquaredError()

vaqae.compile(optimizer=optimizer, loss=loss, metrics = ["accuracy"])


vaqae.summary()

# epochs = 1
# x_train = X[:n_train_size]
# vaqae.fit(x_train, x_train, batch_size=n_batches, epochs=epochs)

Model: "vq_vae"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 44100, 65535)]    0         
                                                                 
 encoder (Functional)        (None, 11025, 32)         6299680   
                                                                 
 vector_quantizer (VectorQu  (None, 11025, 32)         2097120   
 antizer)                                                        
                                                                 
 decoder (Functional)        (None, 44100, 1)          12481     
                                                                 
Total params: 8409281 (32.08 MB)
Trainable params: 8409281 (32.08 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________


In [6]:

class VQVAETrainer(keras.models.Model):
    def __init__(self, train_variance, input_shape, conv_filters = 32, latent_dim=32, num_embeddings=128, **kwargs):
        super().__init__(**kwargs)
        self.train_variance = train_variance
        self.latent_dim = latent_dim
        self.num_embeddings = num_embeddings
        # self.input_shape = input_shape
        # self.conv_filters = conv_filters

        self.vqvae = get_vqvae(input_shape, self.num_embeddings, conv_filters, self.latent_dim)

        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.vq_loss_tracker = keras.metrics.Mean(name="vq_loss")
        self.optimizer = keras.optimizers.Adam(learning_rate=1e-3)
        
        checkpoint = tf.train.Checkpoint(optimizer=self.optimizer, model=self.vqvae)
        self.checkpoint_manager = tf.train.CheckpointManager(checkpoint, './checkpoints/Waveforms_with_VQVAE', checkpoint_name="checkpoint", max_to_keep=5)
        

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.vq_loss_tracker,
        ]

    def train_step(self, x):
        with tf.GradientTape() as tape:
            # Outputs from the VQ-VAE.
            reconstructions = self.vqvae(x)

            # Calculate the losses.
            reconstruction_loss = (
                tf.reduce_mean((x - reconstructions) ** 2) / self.train_variance
            )
            total_loss = reconstruction_loss + sum(self.vqvae.losses)

        # Backpropagation.
        grads = tape.gradient(total_loss, self.vqvae.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.vqvae.trainable_variables))

        # Loss tracking.
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.vq_loss_tracker.update_state(sum(self.vqvae.losses))

        # Log results.
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "vqvae_loss": self.vq_loss_tracker.result(),
        }
        
        
conv_filters = 32
latent_dim = 32
n_batch = 1
input_shape = (set_samples, embedding_size)


# x_train = X[:n_train_size * n_batch]

# trainer = VQVAETrainer(1, input_shape, conv_filters, latent_dim, embedding_size)

# for epoch in  tqdm(range(epochs), "Epoch"):
#     for i in tqdm(range(0, x_train.shape[0], n_batch), "Batch"):
        
#         step_x = np.array(x_train[i:n_batch])
#         result = trainer.train_step(step_x)
        
#     # Checkpoints
#     trainer.checkpoint_manager.save()
    
    
    
def fit_vqvae(data, batch_size, epochs):
    trainer = VQVAETrainer(0.2, input_shape, conv_filters, latent_dim, embedding_size)
    
    for _ in  range(epochs):
        for i in tqdm(range(0, data.shape[0], batch_size), "Batch"):
            step_x = data[i: i + batch_size]
            result = trainer.train_step(step_x)
            for key, value in result.items():
                print(f"{key}: {value}", sep = " | ", end="")

        # Checkpoints
        trainer.checkpoint_manager.save()
        
        
    return trainer.vqvae

In [7]:
epochs = 2
n_train_size = 3

x_train = X[:n_train_size * n_batch]   
model = fit_vqvae(x_train, n_batch, 1)

Batch:   0%|          | 0/3 [00:52<?, ?it/s]


ResourceExhaustedError: {{function_node __wrapped__Pow_device_/job:localhost/replica:0/task:0/device:CPU:0}} OOM when allocating tensor with shape[1,44100,65535] and type float on /job:localhost/replica:0/task:0/device:CPU:0 by allocator cpu [Op:Pow] name: 

In [None]:
vqvae = get_vqvae(input_shape, embedding_size, conv_filters, latent_dim)

# checkpoint_dir = "./checkpoints/Waveforms_with_VQVAE"
# def get_last_checkpoint():
#     pattern = r'checkpoint-(\d+)\.'
#     files = os.listdir(checkpoint_dir)
#     checkpoints = [int(re.match(pattern, file).group(1)) if re.match(pattern, file) else -1 for file in files if file.startswith("checkpoint")]
#     return max(checkpoints)

checkpoint = tf.train.Checkpoint(model=vqvae)
manager = tf.train.CheckpointManager(checkpoint, './checkpoints/Waveforms_with_VQVAE', checkpoint_name="checkpoint")

manager.restore_or_initialize()