In [11]:
import keras
import tensorflow as tf
import scipy.io.wavfile as wavfile
import numpy as np
from tqdm import tqdm
import re
import os

First, let's design the VQ VAE layers, since this is not a part of Tensorflow.

In [12]:

class VectorQuantizer(keras.layers.Layer):
    def __init__(self, num_embeddings, embedding_dim, beta=0.25, **kwargs):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings

        # The `beta` parameter is best kept between [0.25, 2] as per the paper.
        self.beta = beta

        # Initialize the embeddings which we will quantize.
        w_init = tf.random_uniform_initializer()
        self.embeddings = tf.Variable(
            initial_value=w_init(
                shape=(self.embedding_dim, self.num_embeddings), dtype="float32"
            ),
            trainable=True,
            name="embeddings_vqvae",
        )

    def call(self, x):
        # Calculate the input shape of the inputs and
        # then flatten the inputs keeping `embedding_dim` intact.
        input_shape = tf.shape(x)
        flattened = tf.reshape(x, [-1, self.embedding_dim])

        # Quantization.
        encoding_indices = self.get_code_indices(flattened)
        encodings = tf.one_hot(encoding_indices, self.num_embeddings)
        quantized = tf.matmul(encodings, self.embeddings, transpose_b=True)

        # Reshape the quantized values back to the original input shape
        quantized = tf.reshape(quantized, input_shape)

        # Calculate vector quantization loss and add that to the layer. You can learn more
        # about adding losses to different layers here:
        # https://keras.io/guides/making_new_layers_and_models_via_subclassing/. Check
        # the original paper to get a handle on the formulation of the loss function.
        commitment_loss = tf.reduce_mean((tf.stop_gradient(quantized) - x) ** 2)
        codebook_loss = tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2)
        self.add_loss(self.beta * commitment_loss + codebook_loss)

        # Straight-through estimator.
        quantized = x + tf.stop_gradient(quantized - x)
        return quantized

    def get_code_indices(self, flattened_inputs):
        # Calculate L2-normalized distance between the inputs and the codes.
        similarity = tf.matmul(flattened_inputs, self.embeddings)
        distances = (
            tf.reduce_sum(flattened_inputs**2, axis=1, keepdims=True)
            + tf.reduce_sum(self.embeddings**2, axis=0)
            - 2 * similarity
        )

        # Derive the indices for minimum distances.
        encoding_indices = tf.argmin(distances, axis=1)
        return encoding_indices


Now that the VQ VAE is defined, we need to define our encoder and decoder. This can be any model for encoding and decoding. OpenAI's JukeBox project uses noncausal 1-D dilated convolutions, interleaved with downsampling and upsampling 1-D convolutions. So let's use this approach!

In [13]:
def get_encoder(input_shape, conv_filters, latent_dim):
    encoder_inputs = keras.Input(shape=input_shape)
    
    x = keras.layers.Conv1D(conv_filters, 3, activation="relu", strides=2, padding="same")(
        encoder_inputs
    )
    
    x = keras.layers.Conv1D(2 * conv_filters, 3, activation="relu", strides=2, padding="same")(x)
    
    encoder_outputs = keras.layers.Conv1D(latent_dim, 1, padding="same")(x)
    return keras.Model(encoder_inputs, encoder_outputs, name="encoder")


def get_decoder(input_shape, conv_filters, latent_dim):
    latent_inputs = keras.Input(shape=get_encoder(input_shape, conv_filters, latent_dim).output.shape[1:])
    
    x = keras.layers.Conv1DTranspose(
        2 * conv_filters, 3, activation="relu", strides=2, padding="same"
    )(latent_inputs)
    
    x = keras.layers.Conv1DTranspose(
        conv_filters, 3, activation="relu", strides=2, padding="same"
    )(x)
    
    decoder_outputs = keras.layers.Conv1DTranspose(1, 3, padding="same")(x)
    return keras.Model(latent_inputs, decoder_outputs, name="decoder")

def get_vqvae(input_shape, num_embeddings, conv_filters, latent_dim):
    vq_layer = VectorQuantizer(num_embeddings, latent_dim, name="vector_quantizer")
    inputs = keras.Input(shape=input_shape)
    
    encoder = get_encoder(input_shape, conv_filters, latent_dim)
    decoder = get_decoder(input_shape, conv_filters, latent_dim)
    
    encoder_outputs = encoder(inputs)
    quantized_latents = vq_layer(encoder_outputs)
    reconstructions = decoder(quantized_latents)
    return keras.Model(inputs, reconstructions, name="vq_vae")

Let's load from raw wav data and test our network's ability to train.

In [14]:


embedding_size = 2 ** 16 - 1

def load_song_file(song_file: str):
    rate, song_data = wavfile.read(song_file)

    mono_data = song_data

    if song_data.shape[1] == 2:
        mono_data = np.average(song_data, axis=1)
        
    audio = (mono_data / np.max(mono_data))
    audio = (audio * embedding_size) - 2 ** 15
    audio = audio.astype(np.int16)
    return rate, audio
        
        
def get_training_sequences(data, rate, chunk_duration):
    chunk_size = int(rate * chunk_duration)
    
    Xs = []
    for i in range(0, len(data), chunk_size):
        chunk = data[i:i + chunk_size]
        chunk = np.pad(chunk, (0, chunk_size - len(chunk)), mode='constant')
        chunk = tf.one_hot(chunk, depth=embedding_size)
        Xs.append(chunk)
        # Ys.append(chunk[-1])
        
        
    X = np.array(Xs)
    # Y = np.array(Ys)
    
    
    return X

sample_song = "data/Wavs/tvari-tokyo-cafe-159065.wav"
rate, data = load_song_file(sample_song)

test_duration = 10
sequence_duration = 0.2
set_samples = int(sequence_duration * rate)
data = data[:set_samples * test_duration]


X = get_training_sequences(data, rate, sequence_duration)

print(X.shape)
print(set_samples)

        

(10, 8820, 65535)
8820


In [15]:
conv_filters = 32
latent_dim = 32
n_batch = 32

vaqae = get_vqvae((set_samples, embedding_size), embedding_size, conv_filters, latent_dim)

optimizer = keras.optimizers.Adam(learning_rate=1e-3)
loss = keras.losses.MeanSquaredError()

vaqae.compile(optimizer=optimizer, loss=loss, metrics = ["accuracy"])


vaqae.summary()

# epochs = 1
# x_train = X[:n_train_size]
# vaqae.fit(x_train, x_train, batch_size=n_batches, epochs=epochs)

Model: "vq_vae"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_13 (InputLayer)       [(None, 8820, 65535)]     0         
                                                                 
 encoder (Functional)        (None, 2205, 32)          6299680   
                                                                 
 vector_quantizer (VectorQu  (None, 2205, 32)          2097120   
 antizer)                                                        
                                                                 
 decoder (Functional)        (None, 8820, 1)           12481     
                                                                 
Total params: 8409281 (32.08 MB)
Trainable params: 8409281 (32.08 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________


In [16]:

class VQVAETrainer(keras.models.Model):
    def __init__(self, train_variance, input_shape, conv_filters = 32, latent_dim=32, num_embeddings=128, **kwargs):
        super().__init__(**kwargs)
        self.train_variance = train_variance
        self.latent_dim = latent_dim
        self.num_embeddings = num_embeddings
        # self.input_shape = input_shape
        # self.conv_filters = conv_filters

        self.vqvae = get_vqvae(input_shape, self.num_embeddings, conv_filters, self.latent_dim)

        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.vq_loss_tracker = keras.metrics.Mean(name="vq_loss")
        self.optimizer = keras.optimizers.Adam(learning_rate=1e-3)
        
        checkpoint = tf.train.Checkpoint(optimizer=self.optimizer, model=self.vqvae)
        self.checkpoint_manager = tf.train.CheckpointManager(checkpoint, './checkpoints/Waveforms_with_VQVAE', checkpoint_name="checkpoint", max_to_keep=5)

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.vq_loss_tracker,
        ]

    def train_step(self, x):
        with tf.GradientTape() as tape:
            # Outputs from the VQ-VAE.
            reconstructions = self.vqvae(x)

            # Calculate the losses.
            reconstruction_loss = (
                tf.reduce_mean((x - reconstructions) ** 2) / self.train_variance
            )
            total_loss = reconstruction_loss + sum(self.vqvae.losses)

        # Backpropagation.
        grads = tape.gradient(total_loss, self.vqvae.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.vqvae.trainable_variables))

        # Loss tracking.
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.vq_loss_tracker.update_state(sum(self.vqvae.losses))

        # Log results.
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "vqvae_loss": self.vq_loss_tracker.result(),
        }
        
        
conv_filters = 32
latent_dim = 32
n_batch = 1
input_shape = (set_samples, embedding_size)


# x_train = X[:n_train_size * n_batch]

# trainer = VQVAETrainer(1, input_shape, conv_filters, latent_dim, embedding_size)

# for epoch in  tqdm(range(epochs), "Epoch"):
#     for i in tqdm(range(0, x_train.shape[0], n_batch), "Batch"):
        
#         step_x = np.array(x_train[i:n_batch])
#         result = trainer.train_step(step_x)
        
#     # Checkpoints
#     trainer.checkpoint_manager.save()
    
    
    
def fit_vqvae(data, batch_size, epochs):
    trainer = VQVAETrainer(0.2, input_shape, conv_filters, latent_dim, embedding_size)
    
    for _ in  range(epochs):
        for i in tqdm(range(0, data.shape[0], batch_size), "Batch"):
            step_x = data[i: i + batch_size]
            result = trainer.train_step(step_x)
            for key, value in result.items():
                print(f"{key}: {value}", sep = " | ", end="")

        # Checkpoints
        trainer.checkpoint_manager.save()
        
        
    return trainer.vqvae

In [17]:
epochs = 2
n_train_size = 3

x_train = X[:n_train_size * n_batch]   
model = fit_vqvae(x_train, n_batch, 1)

Batch:  33%|███▎      | 1/3 [00:06<00:12,  6.12s/it]

loss: 0.0005145255126990378reconstruction_loss: 9.880522702587768e-05vqvae_loss: 0.00041572030750103295

Batch:  67%|██████▋   | 2/3 [00:11<00:05,  5.94s/it]

loss: 0.0005372579907998443reconstruction_loss: 0.00013322987069841474vqvae_loss: 0.0004040281055495143

Batch: 100%|██████████| 3/3 [00:17<00:00,  5.90s/it]

loss: 0.0005027667502872646reconstruction_loss: 0.00011422794341342524vqvae_loss: 0.0003885387850459665




In [18]:
vqvae = get_vqvae(input_shape, embedding_size, conv_filters, latent_dim)

checkpoint_dir = "./checkpoints/Waveforms_with_VQVAE"
def get_last_checkpoint():
    pattern = r'checkpoint-(\d+)\.'
    files = os.listdir(checkpoint_dir)
    checkpoints = [int(re.match(pattern, file).group(1)) if re.match(pattern, file) else -1 for file in files if file.startswith("checkpoint")]
    return max(checkpoints)

    

vqvae.load_weights(f"{checkpoint_dir}/checkpoint-{get_last_checkpoint()}")

AssertionError: Nothing except the root object matched a checkpointed value. Typically this means that the checkpoint does not match the Python program. The following objects have no matching checkpointed value: [<tf.Variable 'conv1d_30/kernel:0' shape=(3, 65535, 32) dtype=float32, numpy=
array([[[ 3.1859837e-03,  2.4147369e-03, -1.6694725e-03, ...,
         -1.8291040e-03, -2.6177289e-04,  1.7357273e-03],
        [-2.5135255e-03, -4.5673917e-03, -2.1778131e-03, ...,
         -3.6310772e-03,  5.4083904e-03, -1.4033807e-03],
        [-3.5876790e-03,  7.6853437e-04, -5.6006107e-04, ...,
         -2.2224598e-03, -1.5791086e-03,  5.4973783e-03],
        ...,
        [ 2.3595206e-03,  3.8966136e-03,  1.4242004e-03, ...,
          7.7969674e-04, -3.9009377e-05,  3.9720479e-03],
        [-3.5629433e-03, -5.4414067e-03,  4.8809741e-03, ...,
          4.4001173e-03, -2.8449520e-03,  5.4237749e-03],
        [ 5.4534543e-03,  4.5976443e-03, -4.3144114e-03, ...,
          4.9228929e-03, -2.7507984e-03,  3.8739070e-03]],

       [[-5.1633692e-03, -3.7718844e-03, -1.3442561e-03, ...,
         -4.6979450e-04,  4.6077194e-03,  4.4519287e-03],
        [-3.7826675e-03,  1.6576871e-03, -2.5718918e-03, ...,
          4.7918493e-03, -2.2136651e-03, -4.2763511e-03],
        [-5.1356563e-03, -1.1466434e-03, -4.8627807e-03, ...,
         -1.7672973e-03,  3.7822593e-04,  5.2022561e-03],
        ...,
        [-4.2552473e-03,  2.9366547e-03, -4.1953428e-03, ...,
          3.6285780e-03, -5.3478023e-03, -2.0406761e-03],
        [ 3.3119433e-03, -3.6730669e-03,  1.5852344e-03, ...,
          5.4777376e-03, -5.5082953e-03, -5.0519528e-03],
        [-4.1364776e-03, -5.4222108e-03,  1.6961582e-03, ...,
          4.9377559e-04, -5.1352852e-03, -4.3449025e-03]],

       [[-9.5041050e-04,  1.2265798e-03, -1.9765780e-03, ...,
          5.2515129e-03,  4.1370820e-03,  4.5987722e-03],
        [ 6.6309096e-04, -3.5401920e-03,  1.1831312e-03, ...,
         -3.2820432e-03, -3.3968056e-03,  3.1397166e-03],
        [-3.3389109e-03,  1.9180365e-03, -2.9727148e-03, ...,
         -5.3149369e-03, -3.3569559e-03, -1.0492639e-03],
        ...,
        [ 4.5091109e-03, -2.0903607e-03, -1.6199972e-03, ...,
          1.3179849e-03, -8.3639100e-04, -3.6963606e-03],
        [ 2.5240779e-03, -1.0162778e-04,  1.3731290e-03, ...,
         -2.7470719e-03, -5.3463685e-03, -1.8320605e-04],
        [ 4.7342153e-03, -3.6527212e-03,  4.5936638e-03, ...,
         -5.1843957e-03,  8.0442755e-04, -1.8358382e-03]]], dtype=float32)>, <tf.Variable 'conv1d_transpose_17/kernel:0' shape=(3, 1, 32) dtype=float32, numpy=
array([[[-0.08154388, -0.01174907,  0.113837  ,  0.03406376,
         -0.17679563, -0.1865876 , -0.21403912, -0.05896947,
          0.12011567,  0.15494111,  0.08926427, -0.15880921,
         -0.14394802, -0.19183466,  0.16731048, -0.16774364,
          0.06706849,  0.10767704,  0.21678862,  0.22105828,
          0.19796467,  0.09410751, -0.15758783, -0.03965096,
         -0.12124408, -0.03522627, -0.20542282, -0.11443803,
         -0.04697616,  0.12317649, -0.24587242,  0.00504264]],

       [[-0.07125717, -0.05961129,  0.23443967, -0.20154591,
         -0.09992261,  0.24386913,  0.18912402,  0.1030758 ,
         -0.08639428, -0.19838227,  0.23030192,  0.07771915,
         -0.08384305,  0.11614558,  0.12786394,  0.01417083,
          0.1395801 , -0.05985029, -0.16185945, -0.15315463,
          0.01897496,  0.16857064, -0.04549119,  0.14801615,
         -0.12777439, -0.14883512, -0.13611764, -0.17222449,
         -0.08062865, -0.1225164 , -0.24166602, -0.10236356]],

       [[ 0.06882024,  0.17553851,  0.16143638, -0.02033697,
         -0.05426991, -0.12454659,  0.15727141,  0.1670284 ,
         -0.01172918, -0.08040185, -0.20640142,  0.18147126,
         -0.14146501,  0.10026082,  0.18389305, -0.00745168,
          0.10735041,  0.24109039, -0.20862548, -0.00157049,
          0.03616625, -0.20623426,  0.03829193,  0.11895201,
          0.2397081 ,  0.01820338, -0.04466254,  0.15929642,
          0.0692524 , -0.1493024 ,  0.01569095, -0.20330387]]],
      dtype=float32)>, <tf.Variable 'conv1d_31/kernel:0' shape=(3, 32, 64) dtype=float32, numpy=
array([[[ 0.0134176 ,  0.0315565 , -0.12363406, ..., -0.10226497,
         -0.01555102,  0.06672172],
        [-0.03083324, -0.08332602, -0.06957403, ..., -0.05285521,
         -0.01849659,  0.00485657],
        [ 0.04895036, -0.0934306 ,  0.01247899, ..., -0.06719256,
          0.00969177,  0.1022892 ],
        ...,
        [-0.06584523,  0.02753694,  0.10897338, ...,  0.05134931,
          0.06056477, -0.03066645],
        [ 0.12302774, -0.12808466,  0.08843504, ...,  0.13798133,
          0.07368718, -0.08908799],
        [ 0.0549835 ,  0.05663067,  0.06186938, ...,  0.1141268 ,
          0.06791912, -0.10422354]],

       [[-0.09315658,  0.10344137,  0.06282596, ..., -0.0971553 ,
         -0.06323699,  0.0111127 ],
        [ 0.02285658,  0.09423524, -0.09941649, ...,  0.07447423,
          0.0666094 ,  0.13799018],
        [-0.09526642,  0.11759356,  0.08069433, ...,  0.0256554 ,
          0.02175219, -0.07469719],
        ...,
        [-0.1075101 , -0.06870235,  0.0763894 , ...,  0.0734691 ,
          0.10942778, -0.11182983],
        [-0.07507181, -0.08195907,  0.03617947, ..., -0.05581746,
         -0.0361846 ,  0.02509099],
        [ 0.02645347, -0.06087837,  0.10308607, ..., -0.11828067,
          0.11784258,  0.02180302]],

       [[ 0.05438672,  0.12130535, -0.05720047, ...,  0.10046013,
          0.12323436,  0.06498258],
        [-0.07752909,  0.01212832, -0.08583842, ..., -0.08754234,
         -0.03217937, -0.13251747],
        [-0.14246655, -0.06036015,  0.04756463, ...,  0.13602024,
          0.11706239,  0.11196637],
        ...,
        [-0.04782169,  0.00249688, -0.02427036, ..., -0.09003954,
          0.00874738, -0.12035149],
        [ 0.01497085, -0.13379312,  0.13513541, ..., -0.07636353,
         -0.08510918,  0.03023708],
        [-0.04170448, -0.03197953, -0.11718769, ...,  0.06908754,
          0.07115327, -0.10701091]]], dtype=float32)>, <tf.Variable 'conv1d_transpose_16/kernel:0' shape=(3, 32, 64) dtype=float32, numpy=
array([[[ 0.10989046,  0.13917348, -0.07338534, ...,  0.03600191,
         -0.06865896,  0.13270265],
        [ 0.00683445,  0.03824528, -0.02396804, ...,  0.0871051 ,
          0.07592838, -0.08140943],
        [-0.0040115 ,  0.05761859,  0.06494844, ..., -0.12997048,
         -0.09478013, -0.11157012],
        ...,
        [-0.13186727,  0.13046455, -0.09267844, ..., -0.13653176,
         -0.12463537, -0.11212747],
        [-0.07517102,  0.03225759, -0.11844795, ..., -0.11330314,
         -0.13146423, -0.04276993],
        [-0.05563015, -0.12825204, -0.1267446 , ..., -0.14147122,
          0.02683689,  0.14150465]],

       [[ 0.09893382, -0.05116513,  0.06521076, ...,  0.11298025,
          0.08167055, -0.0201943 ],
        [-0.13439348, -0.11682215,  0.08582625, ..., -0.05620635,
         -0.13956149,  0.08771674],
        [-0.03171603, -0.0502078 , -0.10745776, ..., -0.08051897,
         -0.11233965, -0.07662851],
        ...,
        [-0.12143491, -0.03563544,  0.12245643, ..., -0.01109846,
          0.14301443,  0.01297021],
        [-0.12052335, -0.03345354, -0.0761056 , ...,  0.07538617,
         -0.05845602,  0.12099433],
        [ 0.06876738,  0.05425681, -0.11062171, ...,  0.05871449,
          0.06493171,  0.12731147]],

       [[ 0.00553106, -0.10793103, -0.01197037, ..., -0.09799349,
         -0.08222825, -0.14154652],
        [ 0.06075025, -0.07388839,  0.10538632, ...,  0.03242247,
         -0.13218717,  0.02307361],
        [-0.09600753,  0.10775444,  0.06762795, ...,  0.03718959,
          0.13840398,  0.14004773],
        ...,
        [-0.03155743,  0.07768556, -0.06124976, ..., -0.03593597,
          0.01193452, -0.13682048],
        [-0.136408  , -0.1185525 , -0.05642013, ...,  0.06110801,
         -0.0548519 , -0.05053971],
        [ 0.11730891, -0.12683678, -0.06509806, ...,  0.09341979,
         -0.11208169,  0.10363081]]], dtype=float32)>, <tf.Variable 'conv1d_31/bias:0' shape=(64,) dtype=float32, numpy=
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)>, <tf.Variable 'conv1d_32/kernel:0' shape=(1, 64, 32) dtype=float32, numpy=
array([[[ 0.18617207, -0.1190151 , -0.15213913, ...,  0.06268585,
         -0.1949541 , -0.07893205],
        [-0.11320299,  0.00418311, -0.08779937, ...,  0.24682456,
         -0.1542651 , -0.11088341],
        [ 0.19361621, -0.08268344,  0.13770497, ...,  0.03348798,
          0.20348448,  0.19699764],
        ...,
        [ 0.08504218,  0.1413306 ,  0.10230482, ..., -0.0538739 ,
          0.24543566, -0.21085972],
        [ 0.06251132, -0.12058491, -0.07281327, ..., -0.01670623,
         -0.04156226,  0.12352508],
        [ 0.05886728, -0.0575155 ,  0.0089559 , ..., -0.2275474 ,
         -0.22057462,  0.13682663]]], dtype=float32)>, <tf.Variable 'conv1d_transpose_16/bias:0' shape=(32,) dtype=float32, numpy=
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
      dtype=float32)>, <tf.Variable 'conv1d_30/bias:0' shape=(32,) dtype=float32, numpy=
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
      dtype=float32)>, <tf.Variable 'conv1d_transpose_15/kernel:0' shape=(3, 64, 32) dtype=float32, numpy=
array([[[ 0.01743081, -0.09493991,  0.02397971, ...,  0.11834547,
          0.01319839,  0.06897981],
        [ 0.0740373 , -0.07643043, -0.0499955 , ..., -0.10451199,
         -0.14393039, -0.08092177],
        [ 0.07815193, -0.0791543 , -0.13301928, ..., -0.10218558,
          0.14228082, -0.06153672],
        ...,
        [-0.05830385, -0.07124418,  0.00193974, ...,  0.13897005,
         -0.01983765,  0.02357529],
        [ 0.13166717,  0.12548348,  0.04546499, ..., -0.04738554,
          0.1294676 ,  0.04866153],
        [ 0.04216658, -0.04168218,  0.02755192, ..., -0.07563907,
          0.03093159, -0.02661341]],

       [[-0.09514236, -0.00552785, -0.031021  , ...,  0.07660156,
         -0.078469  , -0.02855392],
        [-0.02843853, -0.12962979, -0.09942544, ..., -0.10487852,
         -0.08519343,  0.04152319],
        [ 0.04293796, -0.040635  , -0.05006791, ...,  0.10306762,
         -0.04117142,  0.137171  ],
        ...,
        [ 0.00234085,  0.0419549 ,  0.10103545, ..., -0.00632438,
         -0.08730933,  0.13407603],
        [ 0.12364304,  0.08405602,  0.08143851, ...,  0.03404111,
          0.1011124 , -0.03516564],
        [ 0.0539683 , -0.02334692, -0.04778448, ..., -0.0381081 ,
         -0.07651656,  0.12179554]],

       [[-0.07981309,  0.06849322, -0.08983257, ..., -0.09370033,
          0.1135824 , -0.11707897],
        [-0.06648359, -0.12816532, -0.0341942 , ..., -0.04133203,
          0.12775058, -0.03112596],
        [-0.10607833, -0.09784803,  0.08413465, ...,  0.08753356,
         -0.12152156,  0.11904055],
        ...,
        [-0.14330849, -0.03515641, -0.04583325, ...,  0.13622731,
         -0.10564768, -0.09196968],
        [ 0.02919176, -0.10115406,  0.12492681, ..., -0.11141815,
         -0.08117295, -0.09718142],
        [-0.06592672,  0.02346951, -0.06014579, ...,  0.05301079,
         -0.06997462,  0.01500015]]], dtype=float32)>, <tf.Variable 'conv1d_transpose_15/bias:0' shape=(64,) dtype=float32, numpy=
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)>, <tf.Variable 'conv1d_transpose_17/bias:0' shape=(1,) dtype=float32, numpy=array([0.], dtype=float32)>, <tf.Variable 'embeddings_vqvae:0' shape=(32, 65535) dtype=float32, numpy=
array([[ 0.02750692,  0.00364138,  0.01445637, ..., -0.00656211,
         0.01712495,  0.04033209],
       [-0.04390422, -0.00263307, -0.01812875, ..., -0.01264022,
         0.04101891,  0.00553171],
       [-0.03456918,  0.0073159 , -0.0379118 , ...,  0.00879203,
        -0.03785807, -0.02398562],
       ...,
       [-0.00478645,  0.03579688,  0.042984  , ...,  0.02592951,
         0.00745406, -0.03788954],
       [ 0.01733616, -0.00664086,  0.02838535, ...,  0.0401086 ,
        -0.02424562, -0.00396083],
       [ 0.02078224, -0.04103266, -0.01151044, ...,  0.04186524,
        -0.0110812 , -0.02230898]], dtype=float32)>, <tf.Variable 'conv1d_32/bias:0' shape=(32,) dtype=float32, numpy=
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
      dtype=float32)>]