In [2]:
import torch
from configuration_jamba import JambaConfig
from modeling_jamba import JambaMambaMixer, JambaMLP

from mamba import MambaConfig, MambaBlock
from jamba import JambaConfig as myJambaConfig, JambaMLP as myJambaMLP

# JambaMambaMixer = MambaBlock

In [47]:
# Reference module: the JambaMambaMixer from modeling_jamba, built with
# use_mamba_kernels=False. The second positional argument (0) is presumably the
# layer index — confirm against the mixer's signature.
config = JambaConfig(
    vocab_size=60,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=4,
    use_mamba_kernels=False,
)
mixer = JambaMambaMixer(config, 0)

# Module under test: the local MambaBlock, sized to match (d_model == hidden_size,
# n_layers == num_hidden_layers).
config_mamba = MambaConfig(d_model=64, n_layers=4, rms_norm_eps=1e-6, inner_layernorms=True)
model = MambaBlock(config_mamba)

# Copy every reference weight into the same-named parameter of `model` so both
# modules compute with identical values.
# `nn.Module.get_parameter` raises AttributeError for an unknown name (it never
# returns None), so a naming mismatch stops the cell here instead of silently
# leaving a randomly initialised weight behind.
with torch.no_grad():
    for name, param in mixer.named_parameters():
        model.get_parameter(name).copy_(param)

In [48]:
# Same random input for both modules; the last dimension (64) matches the
# hidden_size / d_model used when `mixer` and `model` were built. A locally seeded
# generator makes the draw reproducible without touching the global RNG state.
x = torch.randn(10, 12, 64, generator=torch.Generator().manual_seed(0))

# Comparison only needs forward values, so skip autograd bookkeeping.
with torch.no_grad():
    expected = mixer(x)[0]  # the reference call's result is indexed at 0
    actual = model(x)

torch.allclose(expected, actual, rtol=0.001)

True

# JambaMLP

In [69]:
# Reference module: the JambaMLP from modeling_jamba.
config = JambaConfig(
    vocab_size=60,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=4,
    use_mamba_kernels=False,
)
mixer = JambaMLP(config)

# Module under test: the local JambaMLP, sized to match (d_model == hidden_size,
# mlp_size == intermediate_size).
config_jamba = myJambaConfig(d_model=64, n_layers=4, rms_norm_eps=1e-6, mlp_size=128, inner_layernorms=True)
model = myJambaMLP(config_jamba)

# Copy every reference weight into the same-named parameter of `model` so both
# modules compute with identical values.
# `nn.Module.get_parameter` raises AttributeError for an unknown name (it never
# returns None), so a naming mismatch stops the cell here instead of silently
# leaving a randomly initialised weight behind.
with torch.no_grad():
    for name, param in mixer.named_parameters():
        model.get_parameter(name).copy_(param)

In [75]:
# Same random input for both MLPs; the last dimension (64) matches the
# hidden_size / d_model used when `mixer` and `model` were built. A locally seeded
# generator makes the draw reproducible without touching the global RNG state.
x = torch.randn(10, 12, 64, generator=torch.Generator().manual_seed(0))

# Comparison only needs forward values, so skip autograd bookkeeping.
with torch.no_grad():
    expected = mixer(x)
    actual = model(x)

# Default allclose tolerances, as in the original check.
torch.allclose(expected, actual)

True

# JambaSparseMoeBlock

In [3]:
from configuration_jamba import JambaConfig
from modeling_jamba import JambaSparseMoeBlock

from jamba import JambaConfig as myJambaConfig, JambaSparseMoeBlock as myJambaSparseMoeBlock

In [12]:
# Reference module: the JambaSparseMoeBlock from modeling_jamba.
# NOTE(review): num_experts_per_tok (2) exceeds num_experts (1) — confirm that
# this single-expert configuration is the intended test case.
config = JambaConfig(
    vocab_size=60,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=4,
    use_mamba_kernels=False,
)
mixer = JambaSparseMoeBlock(config, num_experts=1, num_experts_per_tok=2)

# Module under test: the local JambaSparseMoeBlock with the same expert settings.
config_jamba = myJambaConfig(d_model=64, n_layers=4, rms_norm_eps=1e-6, mlp_size=128, inner_layernorms=True)
model = myJambaSparseMoeBlock(config_jamba, num_experts=1, num_experts_per_tok=2)

# Copy every reference weight into the same-named parameter of `model` so both
# modules compute with identical values.
# `nn.Module.get_parameter` raises AttributeError for an unknown name (it never
# returns None), so a naming mismatch stops the cell here instead of silently
# leaving a randomly initialised weight behind.
with torch.no_grad():
    for name, param in mixer.named_parameters():
        model.get_parameter(name).copy_(param)

In [17]:
# Same random input for both MoE blocks; the last dimension (64) matches the
# hidden_size / d_model used when `mixer` and `model` were built. A locally seeded
# generator makes the draw reproducible without touching the global RNG state.
x = torch.randn(10, 12, 64, generator=torch.Generator().manual_seed(0))

# Comparison only needs forward values, so skip autograd bookkeeping.
with torch.no_grad():
    # Both calls are indexed at 0; only that first element is compared.
    expected = mixer(x)[0]
    actual = model(x)[0]

# Default allclose tolerances, as in the original check.
torch.allclose(expected, actual)

True

# JambaMambaDecoderLayer

In [3]:
from configuration_jamba import JambaConfig
from modeling_jamba import JambaMambaDecoderLayer

from jamba import JambaConfig as myJambaConfig, JambaMambaDecoderLayer as myJambaMambaDecoderLayer

In [4]:
# Reference module: the JambaMambaDecoderLayer from modeling_jamba, built for
# layer_idx=0 with a single expert and use_mamba_kernels=False.
config = JambaConfig(
    vocab_size=60,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=4,
    use_mamba_kernels=False,
)
mixer = JambaMambaDecoderLayer(config, num_experts=1, layer_idx=0)

# Module under test: the local JambaMambaDecoderLayer, also with a single expert.
config_jamba = myJambaConfig(d_model=64, n_layers=4, rms_norm_eps=1e-6, mlp_size=128, inner_layernorms=True)
model = myJambaMambaDecoderLayer(config_jamba, num_experts=1)

# Copy every reference weight into the same-named parameter of `model` so both
# modules compute with identical values.
# `nn.Module.get_parameter` raises AttributeError for an unknown name (it never
# returns None), so a naming mismatch stops the cell here instead of silently
# leaving a randomly initialised weight behind.
with torch.no_grad():
    for name, param in mixer.named_parameters():
        model.get_parameter(name).copy_(param)

The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)` is None. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d. If you want to use the naive implementation, set `use_mamba_kernels=False` in the model config


In [16]:
# Same random input for both decoder layers; the last dimension (64) matches the
# hidden_size / d_model used when `mixer` and `model` were built. A locally seeded
# generator makes the draw reproducible without touching the global RNG state.
x = torch.randn(10, 12, 64, generator=torch.Generator().manual_seed(0))

# Comparison only needs forward values, so skip autograd bookkeeping.
with torch.no_grad():
    # Both calls are indexed at 0; only that first element is compared.
    expected = mixer(x)[0]
    actual = model(x)[0]

torch.allclose(expected, actual, rtol=0.001)

True

# JambaSdpaAttention