In [1]:
%load_ext autoreload
%autoreload 2

In [58]:
import jax
import torch
import numpy
import jax

from moshi.moshi.modules.transformer import ProjectedTransformer
from moshi_jax.moshi_jax.modules.transformer import ProjectedTransformer as JAXTransformer


_seanet_kwargs = {
    "channels": 1,
    "dimension": 512,
    "causal": True,
    "n_filters": 64,
    "n_residual_layers": 1,
    "activation": "ELU",
    "compress": 2,
    "dilation_base": 2,
    "disable_norm_outer_blocks": 0,
    "kernel_size": 7,
    "residual_kernel_size": 3,
    "last_kernel_size": 3,
    # We train using weight_norm but then the weights are pre-processed for inference so
    # that we can use a normal convolution.
    "norm": "none",
    "pad_mode": "constant",
    "ratios": [8, 6, 5, 4],
    "true_skip": True,
}

_transformer_kwargs = {
    "d_model": _seanet_kwargs["dimension"],
    "num_heads": 8,
    "num_layers": 8,
    "causal": True,
    "layer_scale": 0.01,
    "context": 250,
    "conv_layout": True,
    "max_period": 10000,
    "gating": "none",
    "norm": "layer_norm",
    "positional_embedding": "rope",
    "dim_feedforward": 2048,
    "input_dimension": _seanet_kwargs["dimension"],
    "output_dimensions": [_seanet_kwargs["dimension"]],
}

device = torch.get_default_device()

transformer = ProjectedTransformer(
        device=device, **_transformer_kwargs
    )
jax_transformer = JAXTransformer(**_transformer_kwargs, key=jax.random.key(1))
their_params = {key: jax.numpy.array(numpy.array(value.detach())) for key, value in transformer.named_parameters()}

  their_params = {key: jax.numpy.array(numpy.array(value.detach())) for key, value in transformer.named_parameters()}


In [59]:
print(their_params.keys())
print(transformer.output_projs)
print(jax_transformer.output_projs)

dict_keys(['transformer.layers.0.self_attn.in_proj_weight', 'transformer.layers.0.self_attn.out_proj.weight', 'transformer.layers.0.norm1.weight', 'transformer.layers.0.norm1.bias', 'transformer.layers.0.norm2.weight', 'transformer.layers.0.norm2.bias', 'transformer.layers.0.linear1.weight', 'transformer.layers.0.linear2.weight', 'transformer.layers.0.layer_scale_1.scale', 'transformer.layers.0.layer_scale_2.scale', 'transformer.layers.1.self_attn.in_proj_weight', 'transformer.layers.1.self_attn.out_proj.weight', 'transformer.layers.1.norm1.weight', 'transformer.layers.1.norm1.bias', 'transformer.layers.1.norm2.weight', 'transformer.layers.1.norm2.bias', 'transformer.layers.1.linear1.weight', 'transformer.layers.1.linear2.weight', 'transformer.layers.1.layer_scale_1.scale', 'transformer.layers.1.layer_scale_2.scale', 'transformer.layers.2.self_attn.in_proj_weight', 'transformer.layers.2.self_attn.out_proj.weight', 'transformer.layers.2.norm1.weight', 'transformer.layers.2.norm1.bias', 

In [60]:
jax.tree_util.tree_map_with_path(lambda p, x: print(jax.tree_util.keystr(p)), jax_transformer)

.transformer.positional_embedding
.transformer.max_period
.transformer.positional_scale
.transformer.rope.max_period
.transformer.layers[0].linear1.weight
.transformer.layers[0].linear2.weight
.transformer.layers[0].self_attn.embed_dim
.transformer.layers[0].self_attn.causal
.transformer.layers[0].self_attn.context
.transformer.layers[0].self_attn.rope.max_period
.transformer.layers[0].self_attn.num_heads
.transformer.layers[0].self_attn.weights_per_step
.transformer.layers[0].self_attn.in_proj.weight
.transformer.layers[0].self_attn.out_proj.weight
.transformer.layers[0].norm1.weight
.transformer.layers[0].norm1.bias
.transformer.layers[0].norm2.weight
.transformer.layers[0].norm2.bias
.transformer.layers[0].skip_self_attn
.transformer.layers[0].activation
.transformer.layers[0].layer_scale_1.scale
.transformer.layers[0].layer_scale_1.channel_last
.transformer.layers[0].layer_scale_2.scale
.transformer.layers[0].layer_scale_2.channel_last
.transformer.layers[0].weights_per_step
.trans

ProjectedTransformer(
  transformer=StreamingTransformer(
    positional_embedding=None,
    max_period=None,
    positional_scale=None,
    betas=None,
    rope=RotaryEmbedding(max_period=None),
    layers=[
      StreamingTransformerLayer(
        gating=None,
        linear1=Linear(
          weight=None,
          bias=None,
          in_features=512,
          out_features=2048,
          use_bias=False
        ),
        linear2=Linear(
          weight=None,
          bias=None,
          in_features=2048,
          out_features=512,
          use_bias=False
        ),
        self_attn=StreamingMultiheadAttention(
          embed_dim=None,
          causal=None,
          context=None,
          rope=RotaryEmbedding(max_period=None),
          num_heads=None,
          weights_per_step=None,
          in_proj=Linear(
            weight=None,
            bias=None,
            in_features=512,
            out_features=1536,
            use_bias=False
          ),
          out_p

In [61]:
import jax
import jax.tree_util as jtu
import jax.numpy as jnp

def copy_weights(path, x):
    path = jtu.keystr(path)[1:]
    # if "[0].weight"
    
    if "layers[" in path:
        new_path = path.replace("[", ".").replace("]", "")

        if ".weight" in path and "weights" not in path:
            if "in_proj" in path:
                new_path = new_path.replace("in_proj.weight", "in_proj_weight")
            print(f"{path} {x.shape} {their_params[new_path].shape}")
            return their_params[new_path]
        elif "bias" in path:
            print(f"{path} {x.shape} {their_params[new_path].shape}")
            return their_params[new_path]
    return x 

    
jax_transformer = jtu.tree_map_with_path(copy_weights, jax_transformer)

import torch
our_x = jax.random.normal(jax.random.key(1), shape=(1, 512, 101))
their_x= torch.Tensor(numpy.array(our_x))

our_res = jax.vmap(jax_transformer)(our_x)
their_res = transformer(their_x)

print(our_res[0][0,0,:10])
print(their_res[0][0,0,:10])


transformer.layers[0].linear1.weight (2048, 512) (2048, 512)
transformer.layers[0].linear2.weight (512, 2048) (512, 2048)
transformer.layers[0].self_attn.in_proj.weight (1536, 512) (1536, 512)
transformer.layers[0].self_attn.out_proj.weight (512, 512) (512, 512)
transformer.layers[0].norm1.weight (512,) (512,)
transformer.layers[0].norm1.bias (512,) (512,)
transformer.layers[0].norm2.weight (512,) (512,)
transformer.layers[0].norm2.bias (512,) (512,)
transformer.layers[1].linear1.weight (2048, 512) (2048, 512)
transformer.layers[1].linear2.weight (512, 2048) (512, 2048)
transformer.layers[1].self_attn.in_proj.weight (1536, 512) (1536, 512)
transformer.layers[1].self_attn.out_proj.weight (512, 512) (512, 512)
transformer.layers[1].norm1.weight (512,) (512,)
transformer.layers[1].norm1.bias (512,) (512,)
transformer.layers[1].norm2.weight (512,) (512,)
transformer.layers[1].norm2.bias (512,) (512,)
transformer.layers[2].linear1.weight (2048, 512) (2048, 512)
transformer.layers[2].linear2

In [None]:
import jax
import jax.tree_util as jtu
import jax.numpy as jnp

dec_matches = [
    ("first_layer.conv.conv.weight", "model.0.conv.conv.weight"),
    ("first_layer.conv.conv.bias", "model.0.conv.conv.bias"),
    ("blocks[0][1].convtr.convtr.weight", "model.2.convtr.convtr.weight"),
    ("blocks[0][1].convtr.convtr.bias", "model.2.convtr.convtr.bias"),
    ("blocks[0][0][0].blocks[1].conv.conv.weight","model.3.block.3.conv.conv.weight"),
    ("blocks[0][0][0].blocks[1].conv.conv.bias","model.3.block.3.conv.conv.bias"),    ("blocks[0][0][0].blocks[0].conv.conv.weight","model.3.block.1.conv.conv.weight"),
    ("blocks[0][0][0].blocks[0].conv.conv.bias","model.3.block.1.conv.conv.bias"),
    ("blocks[1][1].convtr.convtr.weight", "model.5.convtr.convtr.weight"),
    ("blocks[1][1].convtr.convtr.bias", "model.5.convtr.convtr.bias"),
    ("blocks[1][0][0].blocks[1].conv.conv.weight","model.6.block.3.conv.conv.weight"),
    ("blocks[1][0][0].blocks[1].conv.conv.bias","model.6.block.3.conv.conv.bias"),    ("blocks[1][0][0].blocks[0].conv.conv.weight","model.6.block.1.conv.conv.weight"),
    ("blocks[1][0][0].blocks[0].conv.conv.bias","model.6.block.1.conv.conv.bias"),
    ("last_layer.conv.conv.weight", "model.8.conv.conv.weight"),
    ("last_layer.conv.conv.bias", "model.8.conv.conv.bias"),
]

enc_matches = [
    ("first_layer.conv.conv.weight", "model.0.conv.conv.weight"),
    ("first_layer.conv.conv.bias", "model.0.conv.conv.bias"),
    ("blocks[0][0][0].blocks[1].conv.conv.weight","model.1.block.3.conv.conv.weight"),
    ("blocks[0][0][0].blocks[1].conv.conv.bias","model.1.block.3.conv.conv.bias"),    ("blocks[0][0][0].blocks[0].conv.conv.weight","model.1.block.1.conv.conv.weight"),
    ("blocks[0][0][0].blocks[0].conv.conv.bias","model.1.block.1.conv.conv.bias"),
    ("blocks[0][1].conv.conv.weight", "model.3.conv.conv.weight"),
    ("blocks[0][1].conv.conv.bias", "model.3.conv.conv.bias"),
    ("blocks[1][0][0].blocks[1].conv.conv.weight","model.4.block.3.conv.conv.weight"),
    ("blocks[1][0][0].blocks[1].conv.conv.bias","model.4.block.3.conv.conv.bias"),    ("blocks[1][0][0].blocks[0].conv.conv.weight","model.4.block.1.conv.conv.weight"),
    ("blocks[1][0][0].blocks[0].conv.conv.bias","model.4.block.1.conv.conv.bias"),
    ("blocks[1][1].conv.conv.weight", "model.6.conv.conv.weight"),
    ("blocks[1][1].conv.conv.bias", "model.6.conv.conv.bias"),    
    ("last_layer.conv.conv.weight", "model.8.conv.conv.weight"),
    ("last_layer.conv.conv.bias", "model.8.conv.conv.bias"),
]
print(jax_decoder.blocks[0][1].convtr.convtr.weight.shape)

dec_matches = {k: v for k, v in dec_matches}
enc_matches = {k: v for k, v in enc_matches}

def copy_weights(path, x):
    path = jtu.keystr(path)[1:]
    if path in enc_matches.keys() and "weight" in path:
        return their_enc_params[enc_matches[path]]
    elif path in enc_matches.keys() and "bias" in path:
        return jnp.expand_dims(their_enc_params[enc_matches[path]], -1)
    return x 

def dec_copy_weights(path, x):
    # print(path)

    path = jtu.keystr(path)[1:]
    if path in dec_matches.keys() and "weight" in path:
        print(f"{path} {x.shape} {their_dec_params[dec_matches[path]].shape}")
        if "convtr" in path:
            return jnp.permute_dims(their_dec_params[dec_matches[path]], (1, 0, 2))
        return their_dec_params[dec_matches[path]]
    elif path in dec_matches.keys() and "bias" in path:
        print(path)
        return jnp.expand_dims(their_dec_params[dec_matches[path]], -1)
    # print(path)
    return x 

    
jax_encoder = jtu.tree_map_with_path(copy_weights, jax_encoder)
jax_decoder = jtu.tree_map_with_path(dec_copy_weights, jax_decoder)
print(jax_decoder.blocks[0][1].convtr.convtr.weight.shape)

import torch
our_x = jax.random.normal(jax.random.key(1), shape=(1, 1, 960))
their_x= torch.Tensor(numpy.array(our_x))

our_res = jax.vmap(jax_encoder)(our_x)
their_res = encoder(their_x)

print(our_res.shape)
print(their_res.shape)

their_res = decoder(their_res)
our_res = jax.vmap(jax_decoder)(our_res)

print(our_res.shape)
print(their_res.shape)


(128, 256, 16)
first_layer.conv.conv.weight (256, 512, 7) (256, 512, 7)
first_layer.conv.conv.bias
blocks[0][0][0].blocks[0].conv.conv.weight (64, 128, 3) (64, 128, 3)
blocks[0][0][0].blocks[0].conv.conv.bias
blocks[0][0][0].blocks[1].conv.conv.weight (128, 64, 1) (128, 64, 1)
blocks[0][0][0].blocks[1].conv.conv.bias
blocks[0][1].convtr.convtr.weight (128, 256, 16) (256, 128, 16)
blocks[0][1].convtr.convtr.bias
blocks[1][0][0].blocks[0].conv.conv.weight (32, 64, 3) (32, 64, 3)
blocks[1][0][0].blocks[0].conv.conv.bias
blocks[1][0][0].blocks[1].conv.conv.weight (64, 32, 1) (64, 32, 1)
blocks[1][0][0].blocks[1].conv.conv.bias
blocks[1][1].convtr.convtr.weight (64, 128, 12) (128, 64, 12)
blocks[1][1].convtr.convtr.bias
last_layer.conv.conv.weight (1, 64, 3) (1, 64, 3)
last_layer.conv.conv.bias
(128, 256, 16)


In [8]:
import jax
import numpy
import torch
import jax.numpy as jnp
from moshi.moshi.quantization.core_vq import _run_kmeans
from moshi_jax.moshi_jax.quantization.core_vq import _run_kmeans as jax_run_kmeans

our_x = jnp.concat([jnp.ones((20, 2)) * 2, jnp.ones((20, 2)), jnp.zeros((20, 2))])
x = torch.from_numpy(numpy.array(our_x))

print(_run_kmeans(x, 3))
print(jax_run_kmeans(our_x, 3, key=jax.random.key(2)))

Their mean size: torch.Size([3, 2])
Their sample size: torch.Size([60, 2])
Their bucket size: torch.Size([60])
newmeans size : torch.Size([3, 2])
Their bins size: torch.Size([3])
Poute torch.Size([60, 2])
Their sample size: torch.Size([60, 2])
Their bucket size: torch.Size([60])
newmeans size : torch.Size([3, 2])
Their bins size: torch.Size([3])
Poute torch.Size([60, 2])
Their sample size: torch.Size([60, 2])
Their bucket size: torch.Size([60])
newmeans size : torch.Size([3, 2])
Their bins size: torch.Size([3])
Poute torch.Size([60, 2])
Their sample size: torch.Size([60, 2])
Their bucket size: torch.Size([60])
newmeans size : torch.Size([3, 2])
Their bins size: torch.Size([3])
Poute torch.Size([60, 2])
Their sample size: torch.Size([60, 2])
Their bucket size: torch.Size([60])
newmeans size : torch.Size([3, 2])
Their bins size: torch.Size([3])
Poute torch.Size([60, 2])
Their sample size: torch.Size([60, 2])
Their bucket size: torch.Size([60])
newmeans size : torch.Size([3, 2])
Their bin