In [1]:
# Re-import edited modules (e.g. the local moshi / moshi_jax packages imported below)
# automatically before each cell executes, so edits to the JAX port are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import jax
import torch
import numpy
from huggingface_hub import hf_hub_download
import jax

from moshi.moshi.models.loaders import get_mimi
from moshi_jax.moshi_jax.quantization.vq import (
    SplitResidualVectorQuantizer as JAXQuantizer,
)

# SEANet settings. In the visible cells only "dimension" is read: it sets the
# quantizer's input/output width below.
_seanet_kwargs = dict(
    channels=1,
    dimension=512,
    causal=True,
    n_filters=64,
    n_residual_layers=1,
    activation="ELU",
    compress=2,
    dilation_base=2,
    disable_norm_outer_blocks=0,
    kernel_size=7,
    residual_kernel_size=3,
    last_kernel_size=3,
    # We train using weight_norm but then the weights are pre-processed for inference so
    # that we can use a normal convolution.
    norm="none",
    pad_mode="constant",
    ratios=[8, 6, 5, 4],
    true_skip=True,
)

# Constructor arguments for the JAX split residual vector quantizer. The projections
# map from/to the SEANet dimension; n_q=8 gives the 8 codebook layers
# (1 in rvq_first + 7 in rvq_rest) seen in the weight-copy output.
_quantizer_kwargs = dict(
    dimension=256,
    n_q=8,
    bins=2048,
    input_dimension=_seanet_kwargs["dimension"],
    output_dimension=_seanet_kwargs["dimension"],
)

# NOTE(review): `device` is not used in the visible cells; kept in case later cells read it.
device = torch.get_default_device()

# Fetch the Mimi tokenizer checkpoint (hf_hub_download caches it locally) and build
# both implementations: the reference PyTorch model and a freshly initialised JAX quantizer.
mimi_weight = hf_hub_download(
    "kyutai/moshiko-pytorch-bf16", "tokenizer-e351c8d8-checkpoint125.safetensors"
)
model = get_mimi(mimi_weight)
jax_quantizer = JAXQuantizer(**_quantizer_kwargs, key=jax.random.key(1))

# Snapshot the torch quantizer's named parameters as JAX arrays, keyed by dotted name.
# `.detach().cpu().numpy()` also works when the model lives on an accelerator, and it
# avoids converting via `numpy.array(tensor)` (Tensor.__array__), which emitted a warning here.
their_params = {
    name: jax.numpy.asarray(value.detach().cpu().numpy())
    for name, value in model.quantizer.named_parameters()
}
print(their_params.keys())

  from .autonotebook import tqdm as notebook_tqdm
  key: jax.numpy.array(numpy.array(value.detach()))


dict_keys(['rvq_first.input_proj.weight', 'rvq_first.output_proj.weight', 'rvq_rest.input_proj.weight', 'rvq_rest.output_proj.weight'])


In [4]:
import jax.tree_util as jtu
import jax.numpy as jnp
import numpy


def copy_weights(path, x):
    path = jtu.keystr(path)[1:]
    # if "[0].weight"
    if path in their_params.keys():
        print(path)
        return their_params[path]
    if path == "rvq_first.vq.layers[0]._codebook.embedding":
        print(path)
        return jnp.array(
            numpy.array(model.quantizer.rvq_first.vq.layers[0]._codebook.embedding)
        )
    if ".vq.layers[" in path and "]._codebook.embedding" in path:
        idx =int( path.split("[")[1][0])
        print(path)
        return jnp.array(
            numpy.array(model.quantizer.rvq_rest.vq.layers[idx]._codebook.embedding)
        )
    # print(path)

    return x


# Overwrite every JAX leaf that has a PyTorch counterpart (see copy_weights).
jax_quantizer = jtu.tree_map_with_path(copy_weights, jax_quantizer)

# Identical random input for both implementations. Shape (1, 512, 10): 512 matches the
# quantizer's input_dimension; presumably (batch, channels, time) — TODO confirm.
our_x = jax.random.normal(jax.random.key(1), shape=(1, 512, 10))
their_x = torch.from_numpy(numpy.array(our_x))

# Inference-only comparison: skip building an autograd graph on the torch side.
with torch.no_grad():
    their_result = model.quantizer(their_x, 10)
# The JAX module is vmapped over the leading (batch) axis; the second argument (10) is
# shared by every example. NOTE(review): presumably the frame rate — confirm.
result = jax.vmap(jax_quantizer, in_axes=(0, None))(our_x, 10)

print(their_result)
print(result)

rvq_first.input_proj.weight
rvq_first.output_proj.weight
rvq_first.vq.layers[0]._codebook.embedding_sum
rvq_first.vq.layers[0]._codebook.embedding
rvq_rest.input_proj.weight
rvq_rest.output_proj.weight
rvq_rest.vq.layers[0]._codebook.embedding_sum
rvq_rest.vq.layers[0]._codebook.embedding
rvq_rest.vq.layers[1]._codebook.embedding_sum
rvq_rest.vq.layers[1]._codebook.embedding
rvq_rest.vq.layers[2]._codebook.embedding_sum
rvq_rest.vq.layers[2]._codebook.embedding
rvq_rest.vq.layers[3]._codebook.embedding_sum
rvq_rest.vq.layers[3]._codebook.embedding
rvq_rest.vq.layers[4]._codebook.embedding_sum
rvq_rest.vq.layers[4]._codebook.embedding
rvq_rest.vq.layers[5]._codebook.embedding_sum
rvq_rest.vq.layers[5]._codebook.embedding
rvq_rest.vq.layers[6]._codebook.embedding_sum
rvq_rest.vq.layers[6]._codebook.embedding


  numpy.array(model.quantizer.rvq_rest.vq.layers[idx]._codebook.embedding)
  numpy.array(model.quantizer.rvq_first.vq.layers[0]._codebook.embedding)


T aaa tensor([[ 10.5174,   3.9937,  -7.0157,  10.9582, -12.2908,   3.1525,   3.3307,
           2.8719,   1.2311, -11.9664],
        [  7.1068,  -6.7943,  12.1126,   2.1296,  -0.4890,  -5.1519, -17.5122,
          -6.2433,   9.3379,   3.9586],
        [  7.9820,   1.6699, -10.2319,   5.7777,   1.6705,   1.2606,  -2.4887,
         -11.0096, -10.4015,  20.2741]], grad_fn=<SliceBackward0>)
T shape pre input: torch.Size([1, 10, 256])
T shape pre out: torch.Size([10])
T shape post out: torch.Size([1, 10])
T after quant: tensor([[-0.3790, -0.1061,  1.2380,  ...,  0.1567, -0.9781, -0.6167],
        [ 0.0825,  0.3540,  0.1169,  ..., -0.2204,  0.5010,  0.6133],
        [-0.1008,  1.1796,  0.9707,  ...,  1.1538, -0.5821,  1.2122],
        ...,
        [ 0.1358,  0.0322,  0.0451,  ..., -0.4313,  0.7493,  0.0605],
        [ 0.6849, -0.6331, -1.1305,  ..., -0.0322, -0.5578,  0.2000],
        [ 0.3323,  0.6872,  0.6157,  ...,  1.2430,  0.6165, -0.7309]])
T aaa tensor([[-2.1185,  4.2606,  2.9183, -2.

In [5]:
# Spot-check the first-stage input projection: weights first, then its output on the
# shared input. Each pair is printed JAX-then-PyTorch for side-by-side comparison.
jax_in_proj = jax_quantizer.rvq_first.input_proj
torch_in_proj = model.quantizer.rvq_first.input_proj

for weight in (jax_in_proj.weight, torch_in_proj.weight):
    print(weight[0, :10])

our_y = jax.vmap(jax_in_proj)(our_x)
their_y = torch_in_proj(their_x)

for projected in (our_y, their_y):
    print(projected[0, :10])

[[-0.79974586]
 [-0.50698906]
 [-0.6074174 ]
 [ 0.05502451]
 [ 0.6204352 ]
 [ 0.47274065]
 [-0.42298052]
 [ 0.08264463]
 [ 0.14858195]
 [ 0.39576823]]
tensor([[-0.7997],
        [-0.5070],
        [-0.6074],
        [ 0.0550],
        [ 0.6204],
        [ 0.4727],
        [-0.4230],
        [ 0.0826],
        [ 0.1486],
        [ 0.3958]], grad_fn=<SliceBackward0>)
[[ 10.515468     3.9977965   -7.017041    10.952286   -12.287186
    3.1558828    3.328309     2.8704345    1.2309551  -11.96372   ]
 [  7.105886    -6.790003    12.112903     2.1280365   -0.4910035
   -5.1534443  -17.5139      -6.2409863    9.339567     3.956934  ]
 [  7.982543     1.6725277  -10.23349      5.780754     1.6714315
    1.2587681   -2.4856648  -11.00898    -10.402664    20.271336  ]
 [-13.358841   -12.218836     2.2199643  -10.001466    -4.9096975
   -6.069808     6.6894493  -17.991146    11.907308    -9.91803   ]
 [  6.605486    -3.243204    -5.680114     2.2847328   -3.0460129
    5.167364    -5.451668     4