In [37]:
# Enable IPython's autoreload extension. Mode 2 reloads all (non-excluded) modules
# before each cell executes, so edits to project code such as layers.VQVAE are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
import optax
import equinox as eqx
import jax.numpy as jnp
from tensorboardX import SummaryWriter
from layers.VQVAE import VQVAE
from datasets import load_dataset
import datetime

import jax
%matplotlib inline
import matplotlib.pyplot as plt
from IPython.display import display, clear_output

# Derive two independent keys from a fixed seed; key1 initialises the model.
key1, key2 = jax.random.split(jax.random.key(2))  # `num` defaults to 2

model = VQVAE(key=key1)

optimizer = optax.adam(learning_rate=1e-4)
opt_state = optimizer.init(model)

# One TensorBoard run directory per launch, named by wall-clock timestamp.
run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
writer = SummaryWriter(log_dir=f"./runs/{run_stamp}")

# Training configuration; `step` is the global TensorBoard step counter.
epochs, batch_size, step = 50, 1, 0

# Stream the corpus rather than downloading it up front.
dataset = load_dataset("blabble-io/libritts_r", "clean", streaming=True)

freq = 15
# Chunk length in samples: one chunk per 1/freq seconds at an assumed 22050 Hz.
# NOTE(review): LibriTTS-R is, to my knowledge, distributed at 24 kHz — confirm
# via sample["sampling_rate"] before relying on the 22050 figure.
stride = int(22050 / freq)


def cut_up(samples, stride=stride):
    """Split every waveform in a batched `datasets` example into equal chunks.

    Intended for ``dataset.map(cut_up, batched=True)``. Every complete
    ``stride``-sample window of every waveform becomes one output row; a
    trailing partial window shorter than ``stride`` is dropped.

    Args:
        samples: batch dict whose ``"audio"`` entry is a list of decoded audio
            dicts, each holding its waveform under ``"array"``.
        stride: chunk length in samples. Defaults to the module-level ``stride``
            captured when this cell runs.

    Returns:
        ``{"audio": chunks}`` — the chunks of all waveforms, in order.

    Raises:
        ValueError: if ``stride`` is not positive.
    """
    if stride <= 0:
        raise ValueError(f"stride must be positive, got {stride}")
    chunks = []
    for sample in samples["audio"]:
        waveform = sample["array"]
        # Floor division already excludes the partial tail, so every index in
        # range(n_full) yields a full-length chunk (no extra "- 1" needed).
        n_full = len(waveform) // stride
        chunks.extend(waveform[i * stride:(i + 1) * stride] for i in range(n_full))
    return {"audio": chunks}

# Keep only "audio" (which cut_up rewrites into chunks); the metadata columns
# would no longer line up row-for-row with the re-chunked output.
unused_columns = ['text_normalized', 'text_original', 'speaker_id', 'path', 'chapter_id', 'id']
dataset = dataset.map(cut_up, batched=True, remove_columns=unused_columns)

dataloader = dataset["train.clean.360"].batch(batch_size=batch_size)

# Two stacked axes (ax1, ax2) for ad-hoc plotting.
fig, (ax1, ax2) = plt.subplots(2)
fig.show()


AttributeError: Cannot set attribute initial

In [None]:
def update_codebook_ema(model, updates: tuple, codebook_indices, key=None):
    """Install batch-averaged codebook statistics and re-seed badly used codes.

    Args:
        model: VQ-VAE whose ``quantizer`` exposes ``K`` (number of codes),
            ``D`` (code width) and the fields ``cluster_size``,
            ``codebook_avg`` and ``codebook``.
        updates: ``(cluster_size, codebook_avg, codebook)`` new values carrying
            a leading per-example axis (as produced under ``jax.vmap``); each is
            averaged over axis 0 before being written back.
        codebook_indices: integer code ids selected for the batch. Any shape —
            every element counts as one selection.
        key: PRNG key for the replacement vectors. Required; the ``None``
            default is kept only for signature compatibility.

    Returns:
        A new model with the three quantizer fields replaced.

    Raises:
        ValueError: if ``key`` is None.
    """
    if key is None:
        raise ValueError("update_codebook_ema requires a PRNG `key`.")
    K, D = model.quantizer.K, model.quantizer.D

    avg_updates = jax.tree.map(lambda u: jnp.mean(u, axis=0), updates)

    # Empirical usage frequency of each code. Normalise by the TOTAL number of
    # selections: len() would count only the leading (batch) axis of a
    # multi-dimensional index array and inflate every frequency.
    counts = jnp.histogram(codebook_indices, bins=K, range=(0, K))[0]
    usage = counts / max(jnp.size(codebook_indices), 1)

    # Under a uniform prior each code should receive 1/K of the selections;
    # codes used more than 2x or less than 0.5x that are re-seeded.
    expected = 1 / K
    reset = (usage > 2 * expected) | (usage < 0.5 * expected)

    fresh = jax.random.normal(key, (K, D))
    # NOTE(review): re-seeded codes get cluster_size 0 while codebook_avg is a
    # random vector; if the quantizer ever divides codebook_avg by
    # cluster_size this yields inf/nan — confirm against the quantizer.
    new_values = (
        jnp.where(reset, 0, avg_updates[0]),
        jnp.where(reset[:, None], fresh, avg_updates[1]),
        jnp.where(reset[:, None], fresh, avg_updates[2]),
    )

    def where(q):
        return (
            q.quantizer.cluster_size,
            q.quantizer.codebook_avg,
            q.quantizer.codebook,
        )

    # Update the codebook and its trackers out-of-place.
    return eqx.tree_at(where, model, new_values)


@eqx.filter_jit
@eqx.filter_value_and_grad(has_aux=True)
def calculate_losses(model, x):
    """Run a batch through the VQ-VAE and compute the training loss.

    Because of ``filter_value_and_grad(has_aux=True)``, calling this returns
    ``((total_loss, aux), grads)``, where ``grads`` is the gradient of
    ``total_loss`` with respect to ``model``.

    Args:
        model: VQ-VAE; applied per example via ``jax.vmap`` and returning
            ``(z_e, z_q, codebook_updates, y)``.
        x: batch of input chunks, presumably ``(batch, T)`` audio — TODO confirm.

    Returns:
        ``total_loss`` plus aux ``(reconstruct_loss, commit_loss,
        codebook_updates, y)``, with ``y`` already trimmed.
    """
    z_e, z_q, codebook_updates, y = jax.vmap(model)(x)
    # NOTE(review): the decoder output is trimmed by 2 samples to line up with
    # x; presumably it is 2 samples longer than the input — confirm.
    y = y[:, :-2]

    # Are the inputs and outputs close? Frobenius (element-wise L2) norm over
    # axes (0, 1). ord=2 on two axes would be the *spectral* norm (largest
    # singular value, via SVD), which for batch > 1 only penalises the dominant
    # error direction; for a single-example batch both norms coincide.
    reconstruct_loss = jnp.mean(jnp.linalg.norm(x - y, ord="fro", axis=(0, 1)))

    # Are the encoder outputs z_e close to their codes z_q? z_q is frozen so
    # this term only moves the encoder. The mean averages any remaining axes.
    commit_loss = jnp.mean(
        jnp.linalg.norm(z_e - jax.lax.stop_gradient(z_q), ord="fro", axis=(0, 1))
    )
    # codebook = jnp.mean(codebook_updates[0][2], axis=0) #| hide_line
    # print(codebook.shape) #| hide_line
    # print(codebook) #| hide_line
    # print(jnp.mean(codebook, axis=-1).shape) #| hide_line
    # print(f"STDR: {jnp.std(codebook, axis=-1)}") #| hide_line
    # print(f"log: {jnp.log(jnp.clip(jnp.mean(codebook, axis=-1), min=1e-5)) }") #| hide_line
    # KL_loss = 0.5 * jnp.sum(jnp.mean(codebook, axis=-1)**2 + jnp.var(codebook, axis=-1)  #| hide_line- jnp.log(jnp.clip(jnp.std(codebook, axis=-1), min=1e-6)) - 1) #| hide_line

    total_loss = reconstruct_loss + commit_loss

    return total_loss, (reconstruct_loss, commit_loss, codebook_updates, y)


@eqx.filter_jit
def make_step(model, optimizer, opt_state, x, key):
    """One training step: gradient update, then codebook EMA refresh.

    Args:
        model: current VQ-VAE.
        optimizer: optax gradient transformation (``optax.adam`` in this notebook).
        opt_state: optimizer state for ``model``.
        x: input batch handed to ``calculate_losses``.
        key: PRNG key forwarded to ``update_codebook_ema``.

    Returns:
        ``(model, opt_state, total_loss, reconstruct_loss, commit_loss,
        codebook_updates, y)`` — updated model and optimizer state, the loss
        terms, and the auxiliary forward-pass outputs.
    """
    (total_loss, aux), grads = calculate_losses(model, x)
    reconstruct_loss, commit_loss, codebook_updates, y = aux

    # Gradient step on the model parameters.
    param_updates, new_opt_state = optimizer.update(grads, opt_state, model)
    new_model = eqx.apply_updates(model, param_updates)

    # Afterwards, overwrite the quantizer statistics from this forward pass.
    ema_stats, code_ids = codebook_updates[0], codebook_updates[1]
    new_model = update_codebook_ema(new_model, ema_stats, code_ids, key)

    return (
        new_model,
        new_opt_state,
        total_loss,
        reconstruct_loss,
        commit_loss,
        codebook_updates,
        y,
    )

# Training loop. `step` counts optimizer updates across epochs and is the
# shared x-axis for every TensorBoard series.
for epoch in range(epochs):
    # eqx.tree_serialise_leaves(f"checkpoints/{epoch}.eqx", model)
    for batch in dataloader:
        # key1 was already passed to VQVAE(...) for initialisation; splitting it
        # again here would reuse that randomness. Draw from key2 instead.
        key2, step_key = jax.random.split(key2)
        audio = jnp.array(batch["audio"])
        (
            model,
            opt_state,
            total_loss,
            reconstruct_loss,
            commit_loss,
            codebook_updates,
            y,
        ) = make_step(model, optimizer, opt_state, audio, step_key)

        # Log losses and codebook statistics under the SAME step index, then advance.
        writer.add_scalar('Loss/Total', total_loss, step)
        writer.add_scalar('Loss/Reconstruct', reconstruct_loss, step)
        writer.add_scalar('Loss/Commit', commit_loss, step)
        writer.add_histogram('Codebook Updates/Code ids used', jnp.reshape(codebook_updates[1], -1), step)
        writer.add_histogram('Codebook Updates/Code means', jnp.mean(codebook_updates[0][2], axis=(0,2)), step)
        writer.add_histogram('Codebook Updates/Code stds', jnp.std(codebook_updates[0][2], axis=(0,2)), step)
        step += 1
        # if (step // batch_size) % 20 == 0:
        #     ax1.clear()
        #     ax2.clear()
        #     ax1.plot(batch["audio"][0])
        #     ax2.plot(y[0])
        #     display(fig)
        #     clear_output(wait=True)
    # Make each finished epoch's events visible to TensorBoard.
    writer.flush()
    # plt.imshow(y[0])


KeyboardInterrupt: 

In [None]:
# Dimensions of the learned codebook (presumably (K, D): codes x code width).
codebook = model.quantizer.codebook
codebook.shape

In [None]:

# Plot a window of one raw (un-chunked) utterance. The `dataset` above was
# re-mapped through cut_up, so its "audio" entries are bare chunk arrays rather
# than decoded-audio dicts; build a raw loader here so the cell runs on a fresh
# kernel.
raw_dataset = load_dataset("blabble-io/libritts_r", "clean", streaming=True)
dataloader2 = raw_dataset["train.clean.360"].batch(batch_size=1)
example = next(iter(dataloader2))

# NOTE(review): 3528 == 22050 / 6.25, possibly a chunk length from an earlier
# `freq` setting; the current training chunk length is `stride`.
window_start, window_len = 3528, 3528

fig, (ax1, ax2) = plt.subplots(2)
print(sorted(example.keys()))
ax1.plot(example["audio"][0]["array"][window_start:window_start + window_len])

display(fig)
plt.close(fig)  # avoid the inline backend showing the same figure a second time


In [None]:
# Sanity-check the encoder's output shape on a dummy 1000-sample input.
# Uses dedicated names so running this cell does not overwrite the trained
# `model` (or the training loop's `y`) in the kernel.
from layers.VQVAE import VQVAE

import jax

probe_model = VQVAE(key=jax.random.key(1))

probe_input = jax.numpy.ones(1000)

probe_latent = probe_model.encoder(probe_input)

print(probe_latent.shape)

