In [2]:
import torch
from configuration_jamba import JambaConfig
from modeling_jamba import JambaMambaMixer, JambaMLP

from mamba import MambaConfig, MambaBlock
from jamba import JambaConfig as myJambaConfig, JambaMLP as myJambaMLP

# JambaMambaMixer vs. MambaBlock

In [None]:
# Mixer from `modeling_jamba`, built with use_mamba_kernels=False; it is the source of the weights.
config = JambaConfig(
    vocab_size=60,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=4,
    use_mamba_kernels=False,
)
mixer = JambaMambaMixer(config, 0)

# Block from the local `mamba` module, using the same model width (64).
config_mamba = MambaConfig(d_model=64, n_layers=4, rms_norm_eps=1e-6, inner_layernorms=True)
model = MambaBlock(config_mamba)

# Copy each weight of `mixer` into the parameter of `model` with the same name.
# `get_parameter` raises AttributeError for an unknown name (it never returns
# None), so a naming mismatch between the two modules fails loudly right here.
with torch.no_grad():
    for name, param in mixer.named_parameters():
        counterpart_param = model.get_parameter(name)
        counterpart_param.copy_(param)

In [None]:
# Seeded local generator: the random input is reproducible without touching
# the global RNG state.
generator = torch.Generator().manual_seed(0)
x = torch.randn(10, 12, 64, generator=generator)  # presumably (batch, seq_len, hidden) — confirm

# Only the forward outputs are compared, so no autograd graph is needed.
with torch.no_grad():
    expected = mixer(x)[0]
    actual = model(x)

torch.allclose(expected, actual, rtol=0.001)

# JambaMLP

In [None]:
# MLP from `modeling_jamba`; it is the source of the weights.
config = JambaConfig(
    vocab_size=60,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=4,
    use_mamba_kernels=False,
)
mixer = JambaMLP(config)

# MLP from the local `jamba` module, using the same sizes (64 and 128).
config_jamba = myJambaConfig(
    d_model=64,
    n_layers=4,
    rms_norm_eps=1e-6,
    mlp_size=128,
    inner_layernorms=True,
)
model = myJambaMLP(config_jamba)

# Copy each weight of `mixer` into the parameter of `model` with the same name.
# `get_parameter` raises AttributeError for an unknown name (it never returns
# None), so a naming mismatch between the two modules fails loudly right here.
with torch.no_grad():
    for name, param in mixer.named_parameters():
        counterpart_param = model.get_parameter(name)
        counterpart_param.copy_(param)

In [None]:
# Seeded local generator: the random input is reproducible without touching
# the global RNG state.
generator = torch.Generator().manual_seed(0)
x = torch.randn(10, 12, 64, generator=generator)  # presumably (batch, seq_len, hidden) — confirm

# Only the forward outputs are compared, so no autograd graph is needed.
with torch.no_grad():
    expected = mixer(x)
    actual = model(x)

torch.allclose(expected, actual)

# JambaSparseMoeBlock

In [19]:
from configuration_jamba import JambaConfig
from modeling_jamba import JambaSparseMoeBlock

from jamba import JambaConfig as myJambaConfig, JambaSparseMoeBlock as myJambaSparseMoeBlock

In [25]:
# Sparse MoE block from `modeling_jamba` (8 experts, 2 per token); source of the weights.
config = JambaConfig(
    vocab_size=60,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=4,
    use_mamba_kernels=False,
)
mixer = JambaSparseMoeBlock(config, num_experts=8, num_experts_per_tok=2)

# Sparse MoE block from the local `jamba` module, with the same expert settings.
config_jamba = myJambaConfig(
    d_model=64,
    n_layers=4,
    rms_norm_eps=1e-6,
    mlp_size=128,
    inner_layernorms=True,
)
model = myJambaSparseMoeBlock(config_jamba, num_experts=8, num_experts_per_tok=2)

# Copy each weight of `mixer` into the parameter of `model` with the same name.
# `get_parameter` raises AttributeError for an unknown name (it never returns
# None), so a naming mismatch between the two modules fails loudly right here.
with torch.no_grad():
    for name, param in mixer.named_parameters():
        counterpart_param = model.get_parameter(name)
        counterpart_param.copy_(param)

In [26]:
# Seeded local generator: the random input is reproducible without touching
# the global RNG state.
generator = torch.Generator().manual_seed(0)
x = torch.randn(10, 12, 64, generator=generator)  # presumably (batch, seq_len, hidden) — confirm

# Only the forward outputs are compared, so no autograd graph is needed.
# Both modules' results are indexed with [0] before comparison.
with torch.no_grad():
    expected = mixer(x)[0]
    actual = model(x)[0]

torch.allclose(expected, actual)

True

# JambaMambaDecoderLayer

In [None]:
from configuration_jamba import JambaConfig
from modeling_jamba import JambaMambaDecoderLayer

from jamba import JambaConfig as myJambaConfig, JambaMambaDecoderLayer as myJambaMambaDecoderLayer

In [None]:
# Mamba decoder layer from `modeling_jamba` (single expert, layer index 0); source of the weights.
config = JambaConfig(
    vocab_size=60,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=4,
    use_mamba_kernels=False,
)
mixer = JambaMambaDecoderLayer(config, num_experts=1, layer_idx=0)

# Mamba decoder layer from the local `jamba` module, also with a single expert.
config_jamba = myJambaConfig(
    d_model=64,
    n_layers=4,
    rms_norm_eps=1e-6,
    mlp_size=128,
    inner_layernorms=True,
)
model = myJambaMambaDecoderLayer(config_jamba, num_experts=1)

# Copy each weight of `mixer` into the parameter of `model` with the same name.
# `get_parameter` raises AttributeError for an unknown name (it never returns
# None), so a naming mismatch between the two modules fails loudly right here.
with torch.no_grad():
    for name, param in mixer.named_parameters():
        counterpart_param = model.get_parameter(name)
        counterpart_param.copy_(param)

In [None]:
# Seeded local generator: the random input is reproducible without touching
# the global RNG state.
generator = torch.Generator().manual_seed(0)
x = torch.randn(10, 12, 64, generator=generator)  # presumably (batch, seq_len, hidden) — confirm

# Only the forward outputs are compared, so no autograd graph is needed.
# Both modules' results are indexed with [0] before comparison.
with torch.no_grad():
    expected = mixer(x)[0]
    actual = model(x)[0]

torch.allclose(expected, actual, rtol=0.001)

# JambaSdpaAttention

In [None]:
from configuration_jamba import JambaConfig
from modeling_jamba import JambaSdpaAttention

from jamba import JambaConfig as myJambaConfig, JambaSdpaAttention as myJambaSdpaAttention

In [None]:
# SDPA attention module from `modeling_jamba` (layer index 0); source of the weights.
config = JambaConfig(
    vocab_size=60,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=4,
    use_mamba_kernels=False,
)
mixer = JambaSdpaAttention(config, layer_idx=0)

# SDPA attention module from the local `jamba` module.
config_jamba = myJambaConfig(
    d_model=64,
    n_layers=4,
    rms_norm_eps=1e-6,
    mlp_size=128,
    inner_layernorms=True,
)
model = myJambaSdpaAttention(config_jamba)

# Copy each weight of `mixer` into the parameter of `model` with the same name.
# `get_parameter` raises AttributeError for an unknown name (it never returns
# None), so a naming mismatch between the two modules fails loudly right here.
with torch.no_grad():
    for name, param in mixer.named_parameters():
        counterpart_param = model.get_parameter(name)
        counterpart_param.copy_(param)

In [None]:
# Seeded local generator: the random input is reproducible without touching
# the global RNG state.
generator = torch.Generator().manual_seed(0)
x = torch.randn(10, 12, 64, generator=generator)  # presumably (batch, seq_len, hidden) — confirm

# Only the forward outputs are compared, so no autograd graph is needed.
# The `modeling_jamba` result is indexed with [0]; the local module's result is used as returned.
with torch.no_grad():
    expected = mixer(x)[0]
    actual = model(x)

torch.allclose(expected, actual)

# JambaAttentionDecoderLayer

In [None]:
from configuration_jamba import JambaConfig
from modeling_jamba import JambaAttentionDecoderLayer

from jamba import JambaConfig as myJambaConfig, JambaAttentionDecoderLayer as myJambaAttentionDecoderLayer

In [None]:
# Attention decoder layer from `modeling_jamba` (single expert, layer index 0); source of the weights.
config = JambaConfig(
    vocab_size=60,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=4,
    use_mamba_kernels=False,
)
mixer = JambaAttentionDecoderLayer(config, num_experts=1, layer_idx=0)

# Attention decoder layer from the local `jamba` module, also with a single expert.
config_jamba = myJambaConfig(
    d_model=64,
    n_layers=4,
    rms_norm_eps=1e-6,
    mlp_size=128,
    inner_layernorms=True,
)
model = myJambaAttentionDecoderLayer(config_jamba, num_experts=1)

# Copy each weight of `mixer` into the parameter of `model` with the same name.
# `get_parameter` raises AttributeError for an unknown name (it never returns
# None), so a naming mismatch between the two modules fails loudly right here.
with torch.no_grad():
    for name, param in mixer.named_parameters():
        counterpart_param = model.get_parameter(name)
        counterpart_param.copy_(param)

In [None]:
# Seeded local generator: the random input is reproducible without touching
# the global RNG state.
generator = torch.Generator().manual_seed(0)
x = torch.randn(10, 12, 64, generator=generator)  # presumably (batch, seq_len, hidden) — confirm

# Only the forward outputs are compared, so no autograd graph is needed.
# Both modules' results are indexed with [0] before comparison.
with torch.no_grad():
    expected = mixer(x)[0]
    actual = model(x)[0]

torch.allclose(expected, actual)

# JambaModel

In [3]:
from configuration_jamba import JambaConfig
from modeling_jamba import JambaModel

from jamba import JambaConfig as myJambaConfig, JambaModel as myJambaModel

In [27]:
# Full model from `modeling_jamba`; source of the weights.
# NOTE(review): num_hidden_layers is not passed here, so the JambaConfig default
# applies; the local config below uses n_layers=32 — confirm that the two agree.
config = JambaConfig(
    vocab_size=60,
    hidden_size=64,
    intermediate_size=128,
    use_mamba_kernels=False,
)
mixer = JambaModel(config)

# Full model from the local `jamba` module.
config_jamba = myJambaConfig(
    vocab_size=60,
    d_model=64,
    n_layers=32,
    rms_norm_eps=1e-6,
    mlp_size=128,
    inner_layernorms=True,
)
model = myJambaModel(config_jamba)

# Copy each weight of `mixer` into the parameter of `model` with the same name.
# `get_parameter` raises AttributeError for an unknown name (it never returns
# None), so a naming mismatch between the two models fails loudly right here.
with torch.no_grad():
    for name, param in mixer.named_parameters():
        counterpart_param = model.get_parameter(name)
        counterpart_param.copy_(param)

In [28]:
# Seeded local generator: the random token ids are reproducible without
# touching the global RNG state. Ids are drawn from [0, 60), matching vocab_size=60.
generator = torch.Generator().manual_seed(0)
x = torch.randint(low=0, high=60, size=(10, 12), generator=generator)

# Only the forward outputs are compared, so no autograd graph is needed.
with torch.no_grad():
    expected = mixer(x).last_hidden_state
    actual = model(x)

torch.allclose(expected, actual, rtol=0.001)

True

# Overall: JambaForCausalLM vs. JambaLM

In [5]:
import torch

from configuration_jamba import JambaConfig
from modeling_jamba import JambaForCausalLM

from jamba import JambaLMConfig as myJambaLMConfig, JambaLM as myJambaLM

In [13]:
# Causal-LM model from `modeling_jamba`; source of the weights.
config = JambaConfig(
    vocab_size=60,
    hidden_size=256,
    num_hidden_layers=8,
    intermediate_size=128,
    use_mamba_kernels=False,
)
mixer = JambaForCausalLM(config)

# Language model from the local `jamba` module, with the same width, depth and MLP size.
config_jamba = myJambaLMConfig(
    vocab_size=60,
    d_model=256,
    n_layers=8,
    rms_norm_eps=1e-6,
    mlp_size=128,
    inner_layernorms=True,
)
model = myJambaLM(config_jamba)

# Copy each weight of `mixer` into `model`, translating the parameter names first.
# `get_parameter` raises AttributeError for an unknown name (it never returns
# None), so a name that the translation below gets wrong fails loudly right here.
with torch.no_grad():
    for name, param in mixer.named_parameters():
        # The `model.` component of the source name becomes `jamba.` in the target.
        name = name.replace("model.", "jamba.")

        # The token embedding is addressed as `embedding.weight` in the target.
        if "embed_tokens" in name:
            name = "embedding.weight"

        # The final layernorm is addressed without the `jamba.` component.
        if "final_layernorm" in name:
            name = name.replace("jamba.", "")

        counterpart_param = model.get_parameter(name)
        counterpart_param.copy_(param)

In [24]:
# Seeded local generator: the random token ids are reproducible without
# touching the global RNG state. Ids are drawn from [0, 60), matching vocab_size=60.
generator = torch.Generator().manual_seed(0)
x = torch.randint(low=0, high=60, size=(10, 12), generator=generator)

# Only the forward outputs are compared, so no autograd graph is needed.
with torch.no_grad():
    expected = mixer(x).logits
    actual = model(x)

torch.allclose(expected, actual, rtol=0.01)

True