We copy the DVAE weights from the pretrained XTTS-v2 model into our own JAX/Equinox VQVAE. As a sanity check, we then verify that both models produce the same outputs for the same input.

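We first download the XTTS-v2 DVAE checkpoint (`dvae.pth`). We then clone the XTTSv2 fine-tuning repository, which provides the `TTS` package imported below.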

In [None]:
# Pretrained XTTS-v2 DVAE checkpoint (-nc: keep an existing copy instead of saving dvae.pth.1 on re-runs).
!wget -nc https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/dvae.pth
# HTTPS clone works without an SSH key; skip cloning when the directory already exists so the cell can be re-run.
!test -d XTTSv2-Finetuning-for-New-Languages || git clone https://github.com/nguyenhoanganh2002/XTTSv2-Finetuning-for-New-Languages.git
# Copy (rather than move) the repository contents into the working directory so re-running the cell is harmless.
!cp -r XTTSv2-Finetuning-for-New-Languages/* .

We can now instantiate the DVAE with the configuration XTTS-v2 uses and load the pretrained weights. We then copy those weights into our JAX model.

In [None]:
from TTS.tts.layers.xtts.dvae import DiscreteVAE
import jax
import torch
import equinox as eqx

# Exported from .ipynb using python3 export.py (Exports all cells with the export tag)
from VQVAE import VQVAE

dvae_pretrained = "./dvae.pth"
dvae = DiscreteVAE(
    channels=80,
    normalization=None,
    positional_dims=1,
    num_tokens=1024,
    codebook_dim=512,
    hidden_dim=512,
    num_resnet_blocks=3,
    kernel_size=3,
    num_layers=2,
    use_transposed_convs=False,
)

# Load the pretrained weights *before* taking the parameter snapshot below.
# (Snapshotting first only worked by accident, because `.numpy()` shares memory
# with the tensors that `load_state_dict` copies into.)
# map_location="cpu" lets the checkpoint load on machines without the device it was saved from.
load_result = dvae.load_state_dict(
    torch.load(dvae_pretrained, map_location="cpu"), strict=False
)
# strict=False tolerates key mismatches between checkpoint and model; report them instead of hiding them.
print(f"missing keys: {load_result.missing_keys}")
print(f"unexpected keys: {load_result.unexpected_keys}")
# Inference mode, so later forward passes used for the comparison do not run
# training-only code paths.
dvae.eval()
print(list(dvae.decoder.state_dict()))

# Snapshot every named parameter as an independent NumPy array; `.copy()`
# decouples it from the live torch storage.
torch_params = {
    name: param.detach().cpu().numpy().copy()
    for name, param in dvae.named_parameters()
}

# The codebook does not show up in named_parameters() (presumably a registered
# buffer), so add it under the custom key that the mapping table refers to.
torch_params["dvae.codebook.embed"] = dvae.codebook.embed.detach().cpu().numpy().copy()

# Mapping table used to copy weights: each entry is (jax_path, torch_name), i.e.
# the dotted path of a leaf in the JAX VQVAE and the name of the matching
# PyTorch parameter in `torch_params`. (The variable name says torch->jax, but
# the pairs are ordered JAX path first.)
#
# The table is generated instead of spelled out. Every conv contributes a weight
# pair and a bias pair (all weights listed first, then all biases). The three
# residual blocks share one layout: their convs sit at indices 0, 2 and 4 of the
# torch block's `net`.


def _conv_key_pairs(module, layers):
    """Expand (jax_layer, torch_layer) names into (jax_path, torch_name) pairs.

    Both paths are prefixed with `module`. All `.weight` pairs come first and all
    `.bias` pairs second, each group in the order given by `layers`.
    """
    return [
        (f"{module}.{jax_layer}.{param}", f"{module}.{torch_layer}.{param}")
        for param in ("weight", "bias")
        for jax_layer, torch_layer in layers
    ]


def _res_block_layers(first_torch_index):
    """Layer names for res1..res3, whose torch indices start at `first_torch_index`."""
    return [
        (f"res{block}.conv{conv}", f"{first_torch_index + block - 1}.net.{2 * (conv - 1)}")
        for block in (1, 2, 3)
        for conv in (1, 2, 3)
    ]


torch_to_jax_keys = (
    _conv_key_pairs(
        "encoder",
        [("conv1", "0.0"), ("conv2", "1.0"), *_res_block_layers(2), ("conv3", "5")],
    )
    + [
        # Both JAX codebook leaves are initialised from the single torch codebook.
        ("quantizer.codebook_avg", "dvae.codebook.embed"),
        ("quantizer.codebook", "dvae.codebook.embed"),
    ]
    + _conv_key_pairs(
        "decoder",
        [
            ("conv1", "0"),
            *_res_block_layers(1),
            ("conv2.conv", "4.0.conv"),
            ("conv3.conv", "5.0.conv"),
            ("conv4", "6"),
        ],
    )
)

# Initialize the JAX model parameters.
# A fixed PRNG key makes VQVAE's random initialisation reproducible. The leaves
# named in torch_to_jax_keys are overwritten with the PyTorch weights below; any
# other leaves keep these initial values.
key = jax.random.PRNGKey(1)
model = VQVAE(key=key)

# Function to update the JAX model parameters
def _path_entry_name(entry):
    """Return the textual name of one key-path entry produced by jax.tree_util.

    GetAttrKey (module attributes) exposes `.name`, SequenceKey (list/tuple
    positions) exposes `.idx` and DictKey exposes `.key`; anything else falls
    back to `str(entry)`.
    """
    for attr in ("name", "idx", "key"):
        if hasattr(entry, attr):
            return str(getattr(entry, attr))
    return str(entry)


def update_params(path, x, key_map=None, params=None):
    """Return the converted PyTorch value for the leaf at `path`, or `x` unchanged.

    Intended as the callback of `jax.tree_util.tree_map_with_path`. The key path
    is joined with "." (e.g. "encoder.conv1.weight") and looked up in `key_map`.
    The first matching entry wins.

    Args:
        path: key path of the current leaf, as supplied by tree_map_with_path.
        x: the current leaf value; returned untouched when the path is unmapped.
        key_map: iterable of (jax_path, torch_name) pairs. Defaults to the
            module-level `torch_to_jax_keys`.
        params: mapping torch_name -> array (NumPy array or torch tensor).
            Defaults to the module-level `torch_params`.

    Returns:
        A new JAX array holding the PyTorch value for mapped leaves (biases get a
        trailing singleton axis), otherwise `x` itself.

    Raises:
        ValueError: if the converted value's shape differs from the shape of the
            leaf it replaces. This indicates a wrong entry in the mapping table.
    """
    if key_map is None:
        key_map = torch_to_jax_keys
    if params is None:
        params = torch_params

    path_str = ".".join(_path_entry_name(entry) for entry in path)
    torch_key = next(
        (t_key for j_key, t_key in key_map if j_key == path_str), None
    )
    if torch_key is None:
        return x

    print(path_str)
    value = params[torch_key]
    # Torch tensors (e.g. a buffer stored directly in `params`) go through NumPy first.
    if hasattr(value, "detach"):
        value = value.detach().cpu().numpy()

    if "bias" in path_str:
        # Biases get a trailing singleton axis; the shape check below confirms
        # this matches the JAX leaf being replaced.
        new_leaf = jax.numpy.expand_dims(value, -1)
    else:
        new_leaf = jax.numpy.array(value)

    expected_shape = getattr(x, "shape", None)
    if expected_shape is not None and tuple(new_leaf.shape) != tuple(expected_shape):
        raise ValueError(
            f"shape mismatch for {path_str} <- {torch_key}: "
            f"converted {tuple(new_leaf.shape)}, JAX leaf {tuple(expected_shape)}"
        )
    return new_leaf

# Overwrite every mapped leaf of the freshly initialised JAX model with the
# converted PyTorch weights; unmapped leaves are left untouched.
model = jax.tree_util.tree_map_with_path(update_params, model)

# Every mapping entry must name a real leaf of the JAX model. Otherwise that
# PyTorch weight was silently never copied, so check before saving anything.
model_leaf_paths = {
    ".".join(
        str(getattr(entry, "name", getattr(entry, "idx", getattr(entry, "key", entry))))
        for entry in leaf_path
    )
    for leaf_path, _ in jax.tree_util.tree_flatten_with_path(model)[0]
}
unmatched_keys = sorted(
    {jax_key for jax_key, _ in torch_to_jax_keys} - model_leaf_paths
)
assert not unmatched_keys, f"mapping entries with no matching JAX leaf: {unmatched_keys}"

print(model.quantizer.codebook.shape)
eqx.tree_serialise_leaves("./xttsvqvae.eqx", model)

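Below are rough, exploratory checks that the conversion worked. They walk through the model step by step (encoder, quantiser, then the full encoder–quantiser–decoder pipeline) and compare the PyTorch and JAX outputs at each stage.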

In [None]:
import numpy as np

# Stage check: the first two encoder convolutions must agree on a constant input.
x = torch.ones((1, 80, 300))
# Same input as a JAX array (via NumPy); first confirm the round trip itself is lossless.
x1 = jax.numpy.array(x.numpy())
torch.testing.assert_close(x, torch.from_numpy(np.array(x1)))

print(dvae.encoder[1][0])
print(model.encoder.conv1)

# PyTorch: encoder[0], then the first module of encoder[1]. This is expected to
# correspond to conv1 -> ReLU -> conv2, which is what the JAX side applies below.
y = dvae.encoder[1][0](dvae.encoder[0](x))
# JAX: conv1 -> ReLU -> conv2, each vmapped over the batch axis.
y1 = jax.vmap(model.encoder.conv2)(
    jax.vmap(jax.nn.relu)(jax.vmap(model.encoder.conv1)(x1))
)

# Compare in PyTorch land (JAX -> NumPy -> torch).
torch.testing.assert_close(y, torch.from_numpy(np.array(y1)))

In [None]:
import numpy as np
import jax.numpy as jnp

# Step-by-step check of the vector-quantisation stage. The PyTorch side
# re-derives the nearest-codebook lookup by hand, so the selected code indices
# and the quantised vectors can be compared separately.
x = torch.ones((1, 80, 300))
# The JAX model is called on a single example, so drop the batch axis.
# NOTE: dtype=float64 is only honoured when JAX x64 mode is enabled; otherwise
# JAX falls back to float32 (with a warning).
x1 = jax.numpy.array(x.numpy(), dtype=jnp.float64)[0]

# ---- PyTorch reference ------------------------------------------------------
with torch.no_grad():
    # (batch, time, dim): one vector per time step to quantise.
    # (Named z_e rather than `input` so the builtin is not shadowed.)
    z_e = dvae.encoder(x).permute((0, 2, 1))
    print(z_e.shape)
    flat_t = z_e.reshape(-1, z_e.shape[-1])
    print(flat_t.shape)
    codebook_t = dvae.codebook.embed
    print(f"Their codebook shape{codebook_t.shape}")
    # Squared Euclidean distance ||a||^2 - 2 a.b + ||b||^2 to every code
    # (same evaluation order as before, so rounding is unchanged).
    dist = (
        flat_t.pow(2).sum(1, keepdim=True)
        - 2 * flat_t @ codebook_t
        + codebook_t.pow(2).sum(0, keepdim=True)
    )
    # Nearest code = largest negated distance.
    _, embed_ind = (-dist).max(1)
    embed_ind = embed_ind.view(*z_e.shape[:-1])
    y = dvae.codebook.embed_code(embed_ind).permute((0, 2, 1))
print(y.shape)

# ---- JAX implementation -----------------------------------------------------
print(f"My codebook shape{model.quantizer.codebook.shape}")
z_e_j = jnp.permute_dims(model.encoder(x1), (1, 0))  # (time, dim)
print(z_e_j.shape)
flat_j = jnp.reshape(z_e_j, (-1, model.quantizer.D))
a_squared = jnp.sum(jnp.pow(flat_j, 2), axis=-1, keepdims=True)
b_squared = jnp.sum(jnp.pow(model.quantizer.codebook, 2), axis=0, keepdims=True)
distance = a_squared + b_squared - 2 * jnp.matmul(flat_j, model.quantizer.codebook)
codebook_indices = jnp.argmin(distance, axis=-1)

# Both implementations must pick the same code for every time step.
torch.testing.assert_close(
    embed_ind[0], torch.from_numpy(np.array(codebook_indices)).to(torch.int64)
)

z_q = model.quantizer.codebook.T[codebook_indices]
# Straight-through estimator: numerically z_q (up to rounding), but gradients flow to flat_j.
z_q = flat_j + jax.lax.stop_gradient(z_q - flat_j)
y1 = jnp.permute_dims(z_q, (1, 0))

# Quantised vectors must match too (compare in PyTorch land).
torch.testing.assert_close(y[0], torch.from_numpy(np.array(y1)))

In [None]:
import numpy as np
import jax.numpy as jnp

# End-to-end check: encoder -> quantiser -> decoder must agree between the
# PyTorch DVAE and the converted JAX VQVAE.
x = torch.ones((1, 80, 300))
# Unbatched JAX input. dtype=float64 is only honoured when JAX x64 mode is
# enabled; otherwise JAX falls back to float32.
x1 = jax.numpy.array(x.numpy(), dtype=jnp.float64)[0]

# The codebook module is called directly below. Evaluation mode plus no_grad
# keep any training-only behaviour (e.g. codebook updates, if the DVAE
# implements them) from making the PyTorch weights drift away from the copy
# already stored in `model`.
dvae.eval()
with torch.no_grad():
    # (batch, time, dim); named z_e rather than `input` so the builtin is not shadowed.
    z_e = dvae.encoder(x).permute((0, 2, 1))
    y = dvae.codebook(z_e)[0]
    y = dvae.decoder(y.permute((0, 2, 1)))

y1 = model.encoder(x1)
y1, _ = model.quantizer(y1)
y1 = model.decoder(y1)
print(y.shape)

# Compare in PyTorch land (JAX -> NumPy -> torch).
torch.testing.assert_close(y[0], torch.from_numpy(np.array(y1)))