In [1]:
# Full model (attention + Mamba layers): parity check of HF `JambaForCausalLM` vs `myJambaLM`

In [1]:
import torch
import torch.nn.functional as F

from configuration_jamba import JambaConfig
from modeling_jamba import JambaForCausalLM

from jamba import JambaLMConfig as myJambaLMConfig, JambaLM as myJambaLM

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
torch.manual_seed(0)  # seed parameter init so Restart & Run All reproduces the outputs below

config = JambaConfig(vocab_size=60, hidden_size=256, num_hidden_layers=8, intermediate_size=128, use_mamba_kernels=False)
mixer = JambaForCausalLM(config)

config_jamba = myJambaLMConfig(vocab_size=60, d_model=256, n_layers=8, rms_norm_eps=1e-6, mlp_size=128, inner_layernorms=True, tie_lm_weights=False)
model = myJambaLM(config_jamba)

# Copy every HF parameter into its myJambaLM counterpart so both models start from identical weights.
# Name mapping: "model." -> "jamba."; embed_tokens -> "embedding.weight"; final_layernorm sits at the top level.
unmatched_hf = []   # HF parameter names with no counterpart in myJambaLM
loaded_ids = set()  # id() of every myJambaLM parameter that received HF weights
with torch.no_grad():
    for name, param in mixer.named_parameters():
        target_name = name.replace("model.", "jamba.")
        if "embed_tokens" in target_name:
            target_name = "embedding.weight"
        if "final_layernorm" in target_name:
            target_name = target_name.replace("jamba.", "")

        try:
            counterpart_param = model.get_parameter(target_name)
        except AttributeError:
            # get_parameter raises (it never returns None) when the target does not exist.
            unmatched_hf.append(name)
            continue

        # copy_ broadcasts, so a shape mismatch could otherwise be copied silently.
        if counterpart_param.shape != param.shape:
            raise ValueError(f"Shape mismatch {name} {tuple(param.shape)} -> {target_name} {tuple(counterpart_param.shape)}")
        counterpart_param.copy_(param)
        loaded_ids.add(id(counterpart_param))

not_loaded = [n for n, p in model.named_parameters() if id(p) not in loaded_ids]
if unmatched_hf or not_loaded:
    print(f"HF parameters without a counterpart: {unmatched_hf}")
    print(f"myJambaLM parameters left at random init: {not_loaded}")

The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)` is None. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d. If you want to use the naive implementation, set `use_mamba_kernels=False` in the model config


In [10]:
torch.manual_seed(0)  # reproducible token sample for the forward check
data = torch.randint(0, config.vocab_size, (1, 20))
x = data[:, :-1]  # inputs; targets are not needed for a forward-only comparison

# Forward-only parity check: no autograd graph is needed, so skip building one.
with torch.no_grad():
    logits_hf_fwd = mixer(x).logits
    logits_fwd = model(x)[0]

torch.allclose(logits_hf_fwd, logits_fwd, atol=1e-3)

True

In [11]:
# Clear any gradients from a previous run of this cell so re-running it is idempotent
# (backward() accumulates into .grad).
mixer.zero_grad(set_to_none=True)
model.zero_grad(set_to_none=True)

torch.manual_seed(1)  # fresh but reproducible token sample, distinct from the forward check
data = torch.randint(0, config.vocab_size, (1, 20))
x = data[:, :-1]  # inputs
y = data[:, 1:]   # next-token targets; a strided slice, non-contiguous when batch > 1

# reshape (not view): view(-1) raises on the non-contiguous target slice for batch > 1.
logits_hf = mixer(x).logits
loss_hf = F.cross_entropy(logits_hf.reshape(-1, logits_hf.size(-1)), y.reshape(-1))
loss_hf.backward()

logits = model(x)[0]
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
loss.backward()

loss_hf.item(), loss.item()

In [12]:
GRAD_ATOL = 1e-6  # absolute tolerance for element-wise gradient comparison

grad_mismatches = []  # (HF parameter name, reason) for every parameter that disagrees
for name, param in mixer.named_parameters():
    transformed_name = name.replace("model.", "jamba.")
    if "embed_tokens" in name:
        transformed_name = "embedding.weight"
    if "final_layernorm" in name:
        transformed_name = transformed_name.replace("jamba.", "")

    try:
        counterpart_param = model.get_parameter(transformed_name)
    except AttributeError:
        # get_parameter raises rather than returning None for unknown names.
        grad_mismatches.append((name, "no counterpart parameter"))
        continue

    # A None grad means no gradient reached the parameter; treat it as zero so a
    # None-vs-nonzero pair is reported instead of being silently skipped.
    grad_hf = param.grad if param.grad is not None else torch.zeros_like(param)
    grad_mine = counterpart_param.grad if counterpart_param.grad is not None else torch.zeros_like(counterpart_param)

    if grad_hf.shape != grad_mine.shape:
        grad_mismatches.append((name, f"shape {tuple(grad_hf.shape)} vs {tuple(grad_mine.shape)}"))
    elif not torch.allclose(grad_hf, grad_mine, atol=GRAD_ATOL):
        max_diff = (grad_hf - grad_mine).abs().max().item()
        grad_mismatches.append((name, f"max |diff| = {max_diff:.3e}"))

for name, reason in grad_mismatches:
    print(f"Gradient mismatch found at {name}: {reason}")

gradient_close = not grad_mismatches
gradient_close

In [13]:
# HF side: gradient of the token-embedding matrix produced by loss_hf.backward().
mixer.get_parameter("model.embed_tokens.weight").grad

tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0080, -0.0081, -0.0014,  ...,  0.0044, -0.0035, -0.0176],
        [ 0.0051,  0.0018, -0.0030,  ...,  0.0145, -0.0063,  0.0110],
        ...,
        [-0.0059, -0.0111,  0.0469,  ..., -0.0121,  0.0169, -0.0184],
        [ 0.0163, -0.0131, -0.0047,  ..., -0.0113, -0.0115, -0.0163],
        [ 0.2363,  0.0969, -0.0448,  ...,  0.1178,  0.0344, -0.0113]])

In [14]:
# myJambaLM side: compare against the HF embedding gradient displayed in the previous cell.
model.get_parameter("embedding.weight").grad

tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0080, -0.0081, -0.0014,  ...,  0.0044, -0.0035, -0.0176],
        [ 0.0051,  0.0018, -0.0030,  ...,  0.0145, -0.0063,  0.0110],
        ...,
        [-0.0059, -0.0111,  0.0469,  ..., -0.0121,  0.0169, -0.0184],
        [ 0.0163, -0.0131, -0.0047,  ..., -0.0113, -0.0115, -0.0163],
        [ 0.2363,  0.0969, -0.0448,  ...,  0.1178,  0.0344, -0.0113]])

In [17]:
# Rebuild a 6-layer myJambaLM with explicit attention-layer spacing settings.
# NOTE: this rebinds `config_jamba` and `model`, replacing the 8-layer model compared above.
config_jamba = myJambaLMConfig(
    vocab_size=60,
    d_model=256,
    n_layers=6,
    rms_norm_eps=1e-6,
    mlp_size=128,
    inner_layernorms=True,
    tie_lm_weights=False,
    attn_layer_offset=2,
    attn_layer_period=3,
)
model = myJambaLM(config_jamba)

In [18]:
# Display the module tree of the 6-layer model to see which layer indices are MambaLayer
# vs attention (presumably attention at indices 2 and 5 given attn_layer_offset=2,
# attn_layer_period=3 — confirm in the printed output).
model

JambaLM(
  (embedding): Embedding(60, 256, padding_idx=0)
  (jamba): Jamba(
    (layers): ModuleList(
      (0): MambaLayer(
        (mamba): MambaBlock(
          (in_proj): Linear(in_features=256, out_features=1024, bias=False)
          (conv1d): Conv1d(512, 512, kernel_size=(4,), stride=(1,), padding=(3,), groups=512)
          (x_proj): Linear(in_features=512, out_features=48, bias=False)
          (dt_proj): Linear(in_features=16, out_features=512, bias=True)
          (out_proj): Linear(in_features=512, out_features=256, bias=False)
          (dt_layernorm): RMSNorm()
          (B_layernorm): RMSNorm()
          (C_layernorm): RMSNorm()
        )
        (moe): SparseMoEBlock(
          (experts): ModuleList(
            (0): MLP(
              (gate_proj): Linear(in_features=256, out_features=128, bias=False)
              (down_proj): Linear(in_features=128, out_features=256, bias=False)
              (up_proj): Linear(in_features=256, out_features=128, bias=False)
           