In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import jax
import torch
import numpy
from huggingface_hub import hf_hub_download
import jax

from moshi.moshi.models.loaders import get_mimi
from moshi_jax.moshi_jax.quantization.vq import (
    SplitResidualVectorQuantizer as JAXQuantizer,
)

_seanet_kwargs = {
    "channels": 1,
    "dimension": 512,
    "causal": True,
    "n_filters": 64,
    "n_residual_layers": 1,
    "activation": "ELU",
    "compress": 2,
    "dilation_base": 2,
    "disable_norm_outer_blocks": 0,
    "kernel_size": 7,
    "residual_kernel_size": 3,
    "last_kernel_size": 3,
    # We train using weight_norm but then the weights are pre-processed for inference so
    # that we can use a normal convolution.
    "norm": "none",
    "pad_mode": "constant",
    "ratios": [8, 6, 5, 4],
    "true_skip": True,
}
_quantizer_kwargs = {
    "dimension": 256,
    "n_q": 32,
    "bins": 2048,
    "input_dimension": _seanet_kwargs["dimension"],
    "output_dimension": _seanet_kwargs["dimension"],
}

device = torch.get_default_device()
mimi_weight = hf_hub_download(
    "kyutai/moshiko-pytorch-bf16", "tokenizer-e351c8d8-checkpoint125.safetensors"
)
model = get_mimi(mimi_weight)

their_params = {
    key: jax.numpy.array(numpy.array(value.detach()))
    for key, value in model.quantizer.named_parameters()
}
print(their_params.keys())

  from .autonotebook import tqdm as notebook_tqdm


dict_keys(['rvq_first.input_proj.weight', 'rvq_first.output_proj.weight', 'rvq_rest.input_proj.weight', 'rvq_rest.output_proj.weight'])


In [None]:
jax_quantizer = JAXQuantizer(**_quantizer_kwargs, key=jax.random.key(1))

In [None]:
model.quantizer.n_q
jax_quantizer.n_q

2

In [None]:
our_x = jax.random.normal(jax.random.key(1), shape=(10, 128))
jax_quantizer.vq.layers[0]._codebook._quantize(our_x)

Array([433, 128, 392,   3, 300, 476, 283, 402, 300,  88], dtype=int32)

In [None]:
import jax.tree_util as jtu
import jax.numpy as jnp
import numpy


def copy_weights(path, x):
    path = jtu.keystr(path)[1:]
    # if "[0].weight"
    if path in their_params.keys():
        print(path)
        return their_params[path]
    if path == "rvq_rest.vq.layers[0]._codebook.embedding":
        print(path)
        return jnp.array(
            numpy.array(model.quantizer.rvq_rest.vq.layers[0]._codebook.embedding)
        )
    if path == "rvq_first.vq.layers[0]._codebook.embedding":
        print(path)
        return jnp.array(
            numpy.array(model.quantizer.rvq_first.vq.layers[0]._codebook.embedding)
        )

    return x


jax_quantizer = jtu.tree_map_with_path(copy_weights, jax_quantizer)


our_x = jax.random.normal(jax.random.key(1), shape=(1, 512, 10))
their_x = torch.from_numpy(numpy.array(our_x))

their_result = model.quantizer(their_x, 10)
result = jax.vmap(jax_quantizer, in_axes=(0, None))(our_x, 10)

print(their_result)
print(result)

rvq_first.input_proj.weight
rvq_first.output_proj.weight
rvq_first.vq.layers[0]._codebook.embedding
rvq_rest.input_proj.weight
rvq_rest.output_proj.weight
rvq_rest.vq.layers[0]._codebook.embedding
T after input_proj: torch.Size([1, 256, 10])
T during resquant: torch.Size([1, 256, 10])
T shape pre input: torch.Size([1, 10, 256])
T shape pre out: torch.Size([10])
T shape post out: torch.Size([1, 10])
T after input_proj: torch.Size([1, 256, 10])
T during resquant: torch.Size([1, 256, 10])
T shape pre input: torch.Size([1, 10, 256])
T shape pre out: torch.Size([10])
T shape post out: torch.Size([1, 10])
T during resquant: torch.Size([1, 256, 10])
T shape pre input: torch.Size([1, 10, 256])
T shape pre out: torch.Size([10])
T shape post out: torch.Size([1, 10])
T during resquant: torch.Size([1, 256, 10])
T shape pre input: torch.Size([1, 10, 256])
T shape pre out: torch.Size([10])
T shape post out: torch.Size([1, 10])
T during resquant: torch.Size([1, 256, 10])
T shape pre input: torch.Size