In [1]:
import torch
from mamba_ssm import MambaLMHeadModel
from mamba_ssm.models.config_mamba import MambaConfig
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from fast_llm.models.ssm.external.configuration_ssm_apriel import AprielSSMConfig
from fast_llm.models.ssm.external.modeling_ssm_apriel import AprielSSMForCausalLM
from transformers.cache_utils import StaticCache
from types import SimpleNamespace

# Make sure code changes are reflected without a manual reload
%load_ext autoreload
%autoreload 2


  from .autonotebook import tqdm as notebook_tqdm


In [2]:

# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


In [3]:
# Source transformer checkpoint whose weights are reused for the SSM model.
checkpoint = "ServiceNow-AI/Apriel-5B-Instruct"

# trust_remote_code=True allows transformers to run the modeling code that
# ships with the checkpoint repository.
config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
apriel_model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)

# The state dict is captured before the model is moved to `device`.
apriel_state_dict = apriel_model.state_dict()

# `.to` returns the module, so the model structure is shown as the cell output.
apriel_model.to(device).to(dtype=torch.bfloat16)

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  6.68it/s]


AprielForCausalLM(
  (model): AprielModel(
    (embed_tokens): Embedding(131072, 4096)
    (layers): ModuleList(
      (0-27): 28 x AprielDecoderLayer(
        (self_attn): AprielAttention(
          (q_proj): Linear(in_features=4096, out_features=3072, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=3072, out_features=4096, bias=False)
        )
        (mlp): AprielMLP(
          (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)
          (up_proj): Linear(in_features=4096, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): AprielRMSNorm((4096,), eps=1e-05)
    (ro

In [4]:
# Show the dtype recorded in the loaded model's config.
apriel_model.config.torch_dtype

torch.bfloat16

In [5]:
# Total number of elements across the trainable parameters of the source model.
n_params = sum(
    weight.numel()
    for weight in apriel_model.parameters()
    if weight.requires_grad
)

In [6]:
# Trainable parameter count of the source model, expressed in billions.
n_params/1e9

4.83207168

In [7]:
# Build the SSM config by copying the fields it shares with the source
# transformer config. No `ssm_cfg` is passed, so the mixer settings are left to
# AprielSSMConfig's own defaults.
apriel_ssm_config = AprielSSMConfig(
    **{
        field: getattr(config, field)
        for field in (
            "vocab_size",
            "hidden_size",
            "intermediate_size",
            "num_hidden_layers",
            "hidden_act",
            "initializer_range",
            "use_cache",
            "mlp_bias",
            "tie_word_embeddings",
            "pad_token_id",
            "bos_token_id",
            "eos_token_id",
            "rms_norm_eps",
        )
    }
)

In [8]:
# Instantiate the SSM model from the config built above. The transformer
# weights are not loaded here; that happens in the later load_state_dict cell.
apriel_ssm = AprielSSMForCausalLM(apriel_ssm_config)

In [9]:
# Display the resolved SSM config, including the default `ssm_cfg` values.
apriel_ssm_config

AprielSSMConfig {
  "_attn_implementation_autoset": true,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 8192,
  "mlp_bias": false,
  "model_type": "apriel_ssm",
  "num_hidden_layers": 28,
  "rms_norm_eps": 1e-05,
  "ssm_cfg": {
    "activation": "identity",
    "bias": false,
    "chunk_size": 128,
    "d_inner": 4104,
    "d_state": 64,
    "expand": 1,
    "n_qk_heads": 24,
    "n_v_heads": 24
  },
  "tie_word_embeddings": false,
  "transformers_version": "4.48.1",
  "use_cache": true,
  "vocab_size": 131072
}

In [10]:
# Report the SSM model's trainable parameter count in billions.
print(
    "N params SSM:",
    sum(weight.numel() for weight in apriel_ssm.parameters() if weight.requires_grad) / 1e9,
)


N params SSM: 5.660780512


# Load the Apriel state dict into the SSM model

In [11]:

# Move the SSM model to the target device and cast it to bfloat16 in a single
# pass. The chained `.to(device).to(dtype=...)` form first materializes the
# whole model on the device in its construction dtype and only then casts it.
# `.to` returns the module, so its structure is still shown as the cell output.
apriel_ssm.to(device=device, dtype=torch.bfloat16)

AprielSSMForCausalLM(
  (model): AprielSSMModel(
    (embed_tokens): Embedding(131072, 4096)
    (layers): ModuleList(
      (0-27): 28 x AprielDecoderLayer(
        (mixer): DiscreteMamba2(
          (in_proj): Linear(in_features=4096, out_features=11304, bias=False)
          (conv1d): Conv1d(7176, 7176, kernel_size=(4,), stride=(1,), padding=(3,), groups=7176)
          (act): Identity()
          (out_proj): Linear(in_features=4104, out_features=4096, bias=False)
        )
        (mlp): AprielMLP(
          (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)
          (up_proj): Linear(in_features=4096, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): AprielRMSNorm((4096,), eps=1e-05)
  )
  (lm_head): Linear(in_fe

In [12]:
# Copy every weight the two architectures share (embeddings, MLPs, norms,
# lm_head) from the transformer into the SSM model. strict=False is required
# because the SSM mixer parameters have no counterpart in the transformer.
load_result = apriel_ssm.load_state_dict(apriel_state_dict, strict=False)

# strict=False would also silently accept a shared weight that failed to load,
# so verify that the only parameters left uninitialized are the mixer ones.
non_mixer_missing = [key for key in load_result.missing_keys if ".mixer." not in key]
assert not non_mixer_missing, (
    f"Unexpected missing keys when loading the transformer weights: {non_mixer_missing[:10]}"
)

# Keep the missing/unexpected key report as the cell output.
load_result

_IncompatibleKeys(missing_keys=['model.layers.0.mixer.z_bias', 'model.layers.0.mixer.D', 'model.layers.0.mixer.in_proj.weight', 'model.layers.0.mixer.conv1d.weight', 'model.layers.0.mixer.conv1d.bias', 'model.layers.0.mixer.out_proj.weight', 'model.layers.1.mixer.z_bias', 'model.layers.1.mixer.D', 'model.layers.1.mixer.in_proj.weight', 'model.layers.1.mixer.conv1d.weight', 'model.layers.1.mixer.conv1d.bias', 'model.layers.1.mixer.out_proj.weight', 'model.layers.2.mixer.z_bias', 'model.layers.2.mixer.D', 'model.layers.2.mixer.in_proj.weight', 'model.layers.2.mixer.conv1d.weight', 'model.layers.2.mixer.conv1d.bias', 'model.layers.2.mixer.out_proj.weight', 'model.layers.3.mixer.z_bias', 'model.layers.3.mixer.D', 'model.layers.3.mixer.in_proj.weight', 'model.layers.3.mixer.conv1d.weight', 'model.layers.3.mixer.conv1d.bias', 'model.layers.3.mixer.out_proj.weight', 'model.layers.4.mixer.z_bias', 'model.layers.4.mixer.D', 'model.layers.4.mixer.in_proj.weight', 'model.layers.4.mixer.conv1d.wei

In [13]:

# Re-apply the device/dtype placement after loading the weights, in a single
# pass instead of two chained `.to` calls.
# NOTE(review): load_state_dict presumably copies into the existing parameters,
# so this is expected to be a no-op — confirm whether this cell is still needed.
apriel_ssm.to(device=device, dtype=torch.bfloat16)

AprielSSMForCausalLM(
  (model): AprielSSMModel(
    (embed_tokens): Embedding(131072, 4096)
    (layers): ModuleList(
      (0-27): 28 x AprielDecoderLayer(
        (mixer): DiscreteMamba2(
          (in_proj): Linear(in_features=4096, out_features=11304, bias=False)
          (conv1d): Conv1d(7176, 7176, kernel_size=(4,), stride=(1,), padding=(3,), groups=7176)
          (act): Identity()
          (out_proj): Linear(in_features=4104, out_features=4096, bias=False)
        )
        (mlp): AprielMLP(
          (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)
          (up_proj): Linear(in_features=4096, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): AprielRMSNorm((4096,), eps=1e-05)
  )
  (lm_head): Linear(in_fe

In [14]:
# apriel_ssm.state_dict()

### Save checkpoint

In [15]:
# Write the SSM model (with the transferred weights) to disk.
# NOTE(review): the destination is an absolute, machine-specific path — consider
# taking it from a configurable checkpoint directory.
# NOTE(review): `save_config` looks like the deprecated alias of
# `is_main_process` in transformers' `save_pretrained` — confirm against the
# installed version and drop it if so.
apriel_ssm.save_pretrained("/mnt/checkpoints/ssm/apriel_ssm",
                            save_config=True)




In [60]:
# Inspect the `n_v_heads` attribute of the first layer's mixer.
apriel_ssm.model.layers[0].mixer.n_v_heads

24

In [10]:
# Display the SSM model's module structure.
apriel_ssm

AprielSSMForCausalLM(
  (model): AprielSSMModel(
    (embed_tokens): Embedding(131072, 4096)
    (layers): ModuleList(
      (0-27): 28 x AprielDecoderLayer(
        (mixer): DiscreteMamba2(
          (in_proj): Linear(in_features=4096, out_features=12320, bias=False)
          (conv1d): Conv1d(8192, 8192, kernel_size=(4,), stride=(1,), padding=(3,), groups=8192)
          (act): Identity()
          (out_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): AprielMLP(
          (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)
          (up_proj): Linear(in_features=4096, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): AprielRMSNorm((4096,), eps=1e-05)
    (rotary_emb): AprielRotar

# Try a forward pass

In [51]:
# Inputs for a smoke-test forward pass.
# The batch and sequence sizes are defined once, so the random token batch and
# the inference cache allocated below always agree on shape.
batch_size = 1
max_length = 128

# Random token ids. The 32000 upper bound stays inside the model's vocabulary
# (vocab_size is 131072 in the config displayed earlier in the notebook).
input_ids = torch.randint(0, 32000, (batch_size, max_length), dtype=torch.long, device=device)

# Lightweight container for the inference state passed to the model as
# `inference_params`: the allocated cache, the batch size, and a zero offset
# because no tokens have been processed yet.
state = SimpleNamespace(
    key_value_memory_dict=apriel_ssm.allocate_inference_cache(batch_size, max_length, dtype=torch.bfloat16),
    batch_size=batch_size,
    seqlen_offset=0,
)

static_inputs = {
    "inference_params": state,
    "input_ids": input_ids,
    "use_cache": True,
}

In [56]:
# Call the module itself rather than `.forward` directly, so that any
# registered forward hooks run as part of the pass. The returned output object
# is displayed as the cell output.
apriel_ssm(**static_inputs)

CustomMambaCausalLMOutput(loss=None, logits=tensor([[[-5.4688, -1.6641,  0.4609,  ..., -7.1562, -3.7812, -5.9062],
         [-3.5000,  1.4297,  4.3125,  ..., -5.3438, -4.9375, -2.9844],
         [-3.1094,  0.7930,  2.2969,  ..., -3.1250, -4.1875, -2.1250],
         ...,
         [-5.3438, -3.0938, -3.9062,  ..., -4.9062, -3.0000, -3.9688],
         [-3.0625, -3.2188,  5.6562,  ..., -2.7812, -2.5938, -6.6562],
         [-1.8438, -1.7500,  5.9062,  ..., -3.7188, -2.1250, -0.8281]]],
       device='cuda:0', grad_fn=<ToCopyBackward0>), all_hidden_states=(), last_hidden_state=tensor([[[ 1.2266,  0.5547, -1.1953,  ...,  0.1089, -2.5781,  0.6328],
         [-0.4395,  0.5938, -0.1562,  ..., -0.6719, -0.6367, -0.3086],
         [ 0.0077,  0.6680, -1.0703,  ..., -3.6875,  0.2207,  0.1299],
         ...,
         [-0.0703,  0.4551,  0.1104,  ...,  1.3438,  1.3984,  1.1641],
         [-0.0613,  1.9141, -0.5430,  ..., -1.0312, -0.6680,  0.0518],
         [-0.6172,  0.2148, -0.5977,  ..., -1.2734, -