In [1]:
import torch
from mamba_ssm import MambaLMHeadModel
from mamba_ssm.models.config_mamba import MambaConfig
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from fast_llm.models.ssm.external.aperiel_ssm.configuration_ssm_apriel import AprielSSMConfig
from fast_llm.models.ssm.external.aperiel_ssm.modeling_ssm_apriel import AprielSSMForCausalLM
from transformers.cache_utils import StaticCache
from types import SimpleNamespace

# make sure the code changes reflected without reload
%load_ext autoreload
%autoreload 2


  from .autonotebook import tqdm as notebook_tqdm


#### Compare trained SSM with lr scalors of 0 on everythign except of mixers, with innitial checkpoint

In [3]:
checkpoint_base= "/mnt/checkpoints/ssm/apriel_ssm_instruct_init_mambainlama"
checkpoint_trained = "/mnt/checkpoints/fast_llm_exp/slam_ssm_distill/ssmins_chillmlp-rand_15bteacher-bs768-lr0.0001-sl4096_ti2000_lm2/export/apriel_ssm/500"


In [4]:
model_base = AprielSSMForCausalLM.from_pretrained(checkpoint_base, torch_dtype=torch.bfloat16, trust_remote_code=True)
model_trained = AprielSSMForCausalLM.from_pretrained(checkpoint_trained, torch_dtype=torch.bfloat16, trust_remote_code=True)



The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00,  3.29it/s]
The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
Loading checkpoint shards: 100%|██████████| 2/2 [00:10<00:00,  5.49s/it]


In [17]:
model_base.model.layers[5].mlp.down_proj.weight.sum()

tensor(9.8750, dtype=torch.bfloat16, grad_fn=<SumBackward0>)

In [18]:
model_trained.model.layers[5].mlp.down_proj.weight.sum()

tensor(9.8750, dtype=torch.bfloat16, grad_fn=<SumBackward0>)

# Apriel SSM for distillation

In [2]:

device = "cuda" if torch.cuda.is_available() else "cpu"


In [3]:
checkpoint = "ServiceNow-AI/Apriel-5B-Instruct"
config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
apriel_model     = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True)
apriel_state_dict = apriel_model.state_dict()
apriel_model.to(device).to(dtype=torch.bfloat16)

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  9.44it/s]


AprielForCausalLM(
  (model): AprielModel(
    (embed_tokens): Embedding(131072, 4096)
    (layers): ModuleList(
      (0-27): 28 x AprielDecoderLayer(
        (self_attn): AprielAttention(
          (q_proj): Linear(in_features=4096, out_features=3072, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=3072, out_features=4096, bias=False)
        )
        (mlp): AprielMLP(
          (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)
          (up_proj): Linear(in_features=4096, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): AprielRMSNorm((4096,), eps=1e-05)
    (ro

In [4]:
print("N params SSM:", sum(p.numel() for p in apriel_model.parameters() if p.requires_grad)/1e9)


N params SSM: 4.83207168


In [5]:

apriel_ssm_config = AprielSSMConfig(vocab_size=config.vocab_size, 
                                    hidden_size=config.hidden_size,
                                    intermediate_size=config.intermediate_size,
                                    num_hidden_layers=config.num_hidden_layers,
                                    hidden_act=config.hidden_act,
                                    initializer_range=config.initializer_range,
                                    use_cache=config.use_cache,
                                    mlp_bias=config.mlp_bias,
                                    tie_word_embeddings=config.tie_word_embeddings,
                                    pad_token_id=config.pad_token_id,
                                    bos_token_id=config.bos_token_id,
                                    eos_token_id=config.eos_token_id,
                                    rms_norm_eps=config.rms_norm_eps,
                                    ssm_cfg={
                                        "d_state": 64,
                                        "n_v_heads": 24,
                                        "n_qk_heads": 24,
                                        "expand": 1,
                                        "chunk_size": 128,
                                        "activation": "identity",
                                        "bias": False,
                                        "d_inner": 4104,
                                    })


In [6]:
apriel_ssm = AprielSSMForCausalLM(apriel_ssm_config)


In [7]:
print("N params SSM:", sum(p.numel() for p in apriel_ssm.parameters() if p.requires_grad)/1e9)


N params SSM: 5.660780512


# Load State dict into SSM

In [8]:
apriel_ssm.to(device).to(dtype=torch.bfloat16)

AprielSSMForCausalLM(
  (model): AprielSSMModel(
    (embed_tokens): Embedding(131072, 4096)
    (layers): ModuleList(
      (0-27): 28 x AprielDecoderLayer(
        (mixer): DiscreteMamba2(
          (in_proj): Linear(in_features=4096, out_features=11304, bias=False)
          (conv1d): Conv1d(7176, 7176, kernel_size=(4,), stride=(1,), padding=(3,), groups=7176)
          (act): Identity()
          (out_proj): Linear(in_features=4104, out_features=4096, bias=False)
        )
        (mlp): AprielMLP(
          (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)
          (up_proj): Linear(in_features=4096, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): AprielRMSNorm((4096,), eps=1e-05)
  )
  (lm_head): Linear(in_fe

In [9]:
apriel_ssm.load_state_dict(apriel_state_dict, strict=False)

_IncompatibleKeys(missing_keys=['model.layers.0.mixer.z_bias', 'model.layers.0.mixer.D', 'model.layers.0.mixer.in_proj.weight', 'model.layers.0.mixer.conv1d.weight', 'model.layers.0.mixer.conv1d.bias', 'model.layers.0.mixer.out_proj.weight', 'model.layers.1.mixer.z_bias', 'model.layers.1.mixer.D', 'model.layers.1.mixer.in_proj.weight', 'model.layers.1.mixer.conv1d.weight', 'model.layers.1.mixer.conv1d.bias', 'model.layers.1.mixer.out_proj.weight', 'model.layers.2.mixer.z_bias', 'model.layers.2.mixer.D', 'model.layers.2.mixer.in_proj.weight', 'model.layers.2.mixer.conv1d.weight', 'model.layers.2.mixer.conv1d.bias', 'model.layers.2.mixer.out_proj.weight', 'model.layers.3.mixer.z_bias', 'model.layers.3.mixer.D', 'model.layers.3.mixer.in_proj.weight', 'model.layers.3.mixer.conv1d.weight', 'model.layers.3.mixer.conv1d.bias', 'model.layers.3.mixer.out_proj.weight', 'model.layers.4.mixer.z_bias', 'model.layers.4.mixer.D', 'model.layers.4.mixer.in_proj.weight', 'model.layers.4.mixer.conv1d.wei

In [17]:

apriel_ssm.to(device).to(dtype=torch.bfloat16)

AprielSSMForCausalLM(
  (model): AprielSSMModel(
    (embed_tokens): Embedding(131072, 4096)
    (layers): ModuleList(
      (0-27): 28 x AprielDecoderLayer(
        (mixer): DiscreteMamba2(
          (in_proj): Linear(in_features=4096, out_features=9240, bias=False)
          (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)
          (act): Identity()
          (out_proj): Linear(in_features=3072, out_features=4096, bias=False)
        )
        (mlp): AprielMLP(
          (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)
          (up_proj): Linear(in_features=4096, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): AprielRMSNorm((4096,), eps=1e-05)
  )
  (lm_head): Linear(in_fea

### Save checkpoint

In [10]:
apriel_ssm.save_pretrained("/mnt/checkpoints/ssm/apriel_ssm_instruct_base_din4104", save_config=True)




# ----

### Load Shambhavi's checkpoint

In [5]:
checkpoint = "ServiceNow-AI/Apriel-5B-Instruct"
apriel_model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True)

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  5.07it/s]


In [6]:
apriel_model.dtype

torch.bfloat16

In [7]:
apriel_ssm = AprielSSMForCausalLM.from_pretrained(
    "/mnt/checkpoints_fml/pretrained_models/ssm/runs/mohawk_distributed_stage2_apriel_8GPU/checkpoints/final",
    # "/mnt/checkpoints_fml/pretrained_models/ssm/apriel_ssm_instruct_base",
      trust_remote_code=True,
      device="cuda", torch_dtype=torch.bfloat16)

The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.


Loading checkpoint shards: 100%|██████████| 5/5 [00:18<00:00,  3.61s/it]


In [8]:
apriel_ssm.dtype

torch.bfloat16

In [14]:
apriel_ssm.device


device(type='cuda', index=0)

# Try a forward pass

In [11]:
device = "cuda"
input_ids = torch.randint(0, 32000, (1, 128), dtype=torch.long, device=device)
batch_size = 1
max_length = 128
state = SimpleNamespace()
state.key_value_memory_dict = apriel_ssm.allocate_inference_cache(batch_size, max_length, dtype=torch.bfloat16)
state.batch_size = batch_size
state.seqlen_offset = 0
static_inputs = {"inference_params": state,
        "input_ids": input_ids,
        "use_cache": True,
}

In [12]:
apriel_ssm.to(device)#.to(dtype=torch.bfloat16)

AprielSSMForCausalLM(
  (model): AprielSSMModel(
    (embed_tokens): Embedding(131072, 4096)
    (layers): ModuleList(
      (0-27): 28 x AprielDecoderLayer(
        (mixer): DiscreteMamba2(
          (in_proj): Linear(in_features=4096, out_features=9240, bias=False)
          (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)
          (act): Identity()
          (out_proj): Linear(in_features=3072, out_features=4096, bias=False)
        )
        (mlp): AprielMLP(
          (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)
          (up_proj): Linear(in_features=4096, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): AprielRMSNorm((4096,), eps=1e-05)
  )
  (lm_head): Linear(in_fea

In [13]:
apriel_ssm.forward(**static_inputs)

CustomMambaCausalLMOutput(loss=None, logits=tensor([[[-5.4375,  0.3457, -2.2500,  ..., -5.2500, -4.6250, -4.2500],
         [ 1.1016, -1.2734,  0.3320,  ..., -1.8828,  0.5508, -2.0938],
         [-2.4531,  1.0078,  5.6562,  ..., -3.8906, -4.0625, -3.0625],
         ...,
         [-2.6250,  1.5234, -1.0312,  ..., -4.4062, -4.3438, -1.3594],
         [-7.0000, -2.0781,  4.6250,  ..., -4.4688, -2.9688, -1.5156],
         [-5.7188,  1.2891, -0.2109,  ..., -5.7500, -4.8438, -4.2812]]],
       device='cuda:0', dtype=torch.float32, grad_fn=<ToCopyBackward0>), all_hidden_states=(), last_hidden_state=tensor([[[ 0.7344,  0.8555, -3.6562,  ..., -1.9688, -0.3516,  2.0625],
         [-0.0986,  0.5859, -0.0806,  ..., -0.3965, -0.0229, -0.0219],
         [ 0.3477,  0.5977, -1.0000,  ..., -0.6367, -1.1172, -0.6797],
         ...,
         [ 1.2969,  0.6562, -1.9844,  ..., -0.1299, -1.5859,  2.5000],
         [ 1.8750, -0.6016, -3.3281,  ..., -0.8242,  0.6133,  3.1250],
         [ 0.2520, -0.7656,  0.6

## Load Apriel SSM into HF class

In [130]:
import torch
from mamba_ssm import MambaLMHeadModel
from mamba_ssm.models.config_mamba import MambaConfig
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from fast_llm.models.ssm.external.aperiel_ssm.configuration_ssm_apriel import AprielSSMConfig
from fast_llm.models.ssm.external.aperiel_ssm.modeling_ssm_apriel import AprielSSMForCausalLM
from transformers.cache_utils import StaticCache
from types import SimpleNamespace
import os
import shutil
# make sure the code changes reflected without reload
%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
model_path = "/mnt/checkpoints/fast_llm_exp/slam_ssm_distill/apriel_ssminstr-distil-randinit-bs768-lr0.0003-sl4096_ti5000_luke_mix1/export/apriel_ssm/5000"
modeling_path = "/home/toolkit/dev/Fast-LLM/fast_llm/models/ssm/external"
# # copy the config.json to the model path
shutil.copy(os.path.join(modeling_path, "modeling_ssm_apriel.py"), os.path.join(model_path, "modeling_ssm_apriel.py"))
shutil.copy(os.path.join(modeling_path, "configuration_ssm_apriel.py"), os.path.join(model_path, "configuration_ssm_apriel.py"))

tokenizer_path = "/mnt/checkpoints/upstream/Mistral-Nemo-Base-2407/"
# # cp tokenizer*
# shutil.copy(os.path.join(tokenizer_path, "tokenizer.json"), os.path.join(model_path, "tokenizer.json"))
# shutil.copy(os.path.join(tokenizer_path, "tokenizer_config.json"), os.path.join(model_path, "tokenizer_config.json"))
# shutil.copy(os.path.join(tokenizer_path, "special_tokens_map.json"), os.path.join(model_path, "special_tokens_map.json"))
# shutil.copy(os.path.join(tokenizer_path, "vocab.json"), os.path.join(model_path, "vocab.json"))


In [4]:

apriel_ssm = AprielSSMForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, trust_remote_code=True, device="cuda")


The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.08s/it]


In [5]:
apriel_ssm

AprielSSMForCausalLM(
  (model): AprielSSMModel(
    (embed_tokens): Embedding(131072, 4096)
    (layers): ModuleList(
      (0-27): 28 x AprielDecoderLayer(
        (mixer): DiscreteMamba2(
          (in_proj): Linear(in_features=4096, out_features=9240, bias=False)
          (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)
          (act): Identity()
          (out_proj): Linear(in_features=3072, out_features=4096, bias=False)
        )
        (mlp): AprielMLP(
          (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)
          (up_proj): Linear(in_features=4096, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): AprielRMSNorm((4096,), eps=1e-05)
  )
  (lm_head): Linear(in_fea

In [6]:
config = apriel_ssm.config

# Mamba in Llama: SSM hybrid 

In [2]:

from fast_llm.models.ssm.external.aperiel_ssm.configuration_ssm_apriel import AprielSSMConfig
import torch
from mamba_ssm import MambaLMHeadModel
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from fast_llm.models.ssm.external.aperiel_ssm.configuration_ssm_apriel import AprielSSMConfig
from fast_llm.models.ssm.external.aperiel_ssm.modeling_ssm_apriel import AprielSSMForCausalLM
from transformers.cache_utils import StaticCache
from types import SimpleNamespace
from fast_llm.models.ssm.external.apriel_hybrid.modeling_ssm_hybrid_apriel import AprielSSMHybridConfig
from fast_llm.models.ssm.external.apriel_hybrid.modeling_ssm_hybrid_apriel import AprielSSMHybridModel, AprielSSMDecoderLayer, AprielSSMHybridForCausalLM, AprielIdentityLayer
# from fast_llm.models.ssm.external.__hybrid_wrapper import MambaTransformerHybridModelWrapper
# make sure the code changes reflected without reload
%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [13]:
base = 0.612615
layer_scores = {
    "22": 0.607389,
    "24": 0.603498,
    "19": 0.597907,
    "27": 0.597173,
    "20": 0.590442,
    "5": 0.578949,
    "4": 0.576852,
    "9": 0.576484,
    "23": 0.574833,
    "7": 0.571860,
    "8": 0.571790,
    "6": 0.571614,
    "2": 0.571330,
    "26": 0.570205,
    "11": 0.567128,
    "14": 0.566175,
    "15": 0.566076,
    "3": 0.562861,
    "1": 0.560154,
    "13": 0.559304,
    "16": 0.559017,
    "10": 0.558789,
    "12": 0.555186,
    "17": 0.554236,
    "25": 0.549215,
    "18": 0.537257,
    "0": 0.233085,
}
layer_scores = {k: base - v for k, v in layer_scores.items()}
layer_importanfce = sorted(layer_scores.items(), key=lambda x: x[1])

In [14]:
layer_importanfce

[('22', 0.005226000000000064),
 ('24', 0.009117000000000042),
 ('19', 0.014708000000000054),
 ('27', 0.015442000000000067),
 ('20', 0.022173),
 ('5', 0.033665999999999974),
 ('4', 0.03576299999999999),
 ('9', 0.036131000000000024),
 ('23', 0.03778199999999998),
 ('7', 0.040754999999999986),
 ('8', 0.040825),
 ('6', 0.041001000000000065),
 ('2', 0.041285000000000016),
 ('26', 0.04241000000000006),
 ('11', 0.045487000000000055),
 ('14', 0.04644000000000004),
 ('15', 0.046539),
 ('3', 0.049754000000000076),
 ('1', 0.05246099999999998),
 ('13', 0.053311),
 ('16', 0.053598000000000035),
 ('10', 0.05382600000000004),
 ('12', 0.05742900000000006),
 ('17', 0.05837900000000007),
 ('25', 0.06340000000000001),
 ('18', 0.07535800000000004),
 ('0', 0.37953000000000003)]

In [16]:

checkpoint = "ServiceNow-AI/Apriel-5B-Instruct"
# checkpoint = "/mnt/checkpoints/upstream/Apriel-5B-Instruct-llamafied"
config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
device = "cuda"
n_hybrid = 14


# d_xb = config.num_key_value_heads * config.head_dim
d_inner = config.num_attention_heads * config.head_dim
d_state = config.head_dim

# config.num_hidden_layers = 4
hybrid_block_layout = ["t", "t"] * 14
for i in range(n_hybrid):
    hybrid_block_layout[int(layer_importanfce[i][0])] = "m2d"

hybrdif_apriel_config = AprielSSMHybridConfig(**config.to_dict(),
                                              hybrid_block_layout=hybrid_block_layout,
                                              ssm_cfg={
                                                  "d_state": 64,
                                                  "n_v_heads": 24,
                                                  "n_qk_heads": 24,
                                                  "expand": 1,
                                                  "chunk_size": 128,
                                                  "activation": "identity",
                                                  "bias": False,
                                                  "d_inner": 24 * 128,  # num_heads * head_dim
                                              })
# hybrdif_apriel_config

In [19]:
hybrid_apriel_model = AprielSSMHybridForCausalLM(hybrdif_apriel_config)
# hybrid_apriel_model.to(device).to(dtype=torch.bfloat16)
hybrid_apriel_model.to(dtype=torch.bfloat16)

In [4]:
# save random small model for debugging
# hybrid_apriel_model.save_pretrained("/mnt/checkpoints/ssm/apriel_ssm_instruct_hybrid_ssm2nd_init_mambainlama_debug", save_config=True)

In [21]:
checkpoint = "ServiceNow-AI/Apriel-5B-Instruct"
# checkpoint = "/mnt/checkpoints/upstream/Apriel-5B-Instruct-llamafied"
config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
apriel_model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True)
# apriel_state_dict = apriel_model.state_dict()
# apriel_model.to(device).to(dtype=torch.bfloat16)
apriel_state_dict = apriel_model.state_dict()

Loading checkpoint shards:  50%|█████     | 1/2 [00:00<00:00,  2.38it/s]

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.47it/s]


In [22]:
hybrid_apriel_model.load_state_dict(apriel_state_dict, strict=False)

_IncompatibleKeys(missing_keys=['model.layers.2.mixer.z_bias', 'model.layers.2.mixer.D', 'model.layers.2.mixer.in_proj.weight', 'model.layers.2.mixer.conv1d.weight', 'model.layers.2.mixer.conv1d.bias', 'model.layers.2.mixer.out_proj.weight', 'model.layers.4.mixer.z_bias', 'model.layers.4.mixer.D', 'model.layers.4.mixer.in_proj.weight', 'model.layers.4.mixer.conv1d.weight', 'model.layers.4.mixer.conv1d.bias', 'model.layers.4.mixer.out_proj.weight', 'model.layers.5.mixer.z_bias', 'model.layers.5.mixer.D', 'model.layers.5.mixer.in_proj.weight', 'model.layers.5.mixer.conv1d.weight', 'model.layers.5.mixer.conv1d.bias', 'model.layers.5.mixer.out_proj.weight', 'model.layers.6.mixer.z_bias', 'model.layers.6.mixer.D', 'model.layers.6.mixer.in_proj.weight', 'model.layers.6.mixer.conv1d.weight', 'model.layers.6.mixer.conv1d.bias', 'model.layers.6.mixer.out_proj.weight', 'model.layers.7.mixer.z_bias', 'model.layers.7.mixer.D', 'model.layers.7.mixer.in_proj.weight', 'model.layers.7.mixer.conv1d.wei

In [23]:
# Innitialization using k, q, v from Apriel transformer
def expand_k_q(k):
    Hq = config.num_attention_heads
    Hk = config.num_key_value_heads
    d_head = config.head_dim
    d = k.shape[-1]
    
    # Expand k
    repeat_factor = Hq // Hk
    k_expanded = k.view(Hk, d_head, d)
    k_expanded = k_expanded.repeat_interleave(repeat_factor, dim=0)
    k_expanded = k_expanded.view(d_head * Hq, d)
    return k_expanded

def mil_innit(hybrid_apriel_model, apriel_model, SSMBLOCKCLASS = AprielSSMDecoderLayer):
    for i, (block_h, block_t) in enumerate(zip(hybrid_apriel_model.model.layers, apriel_model.model.layers)):
        # print(isinstance(block_h, AprielSSMDecoderLayer))
        # print(i, block_h.__class__)
        # print(block_h.__class__.__name__, isinstance(block_h, SSMBLOCKCLASS))
        if isinstance(block_h, SSMBLOCKCLASS):
            print("Innitiating SSM layer")
            # print(block_h.mixer.n_v_heads)
            # print(block_t.self_attn.v_proj.weight.shape)
            # print(block_h.mixer.in_proj.weight.shape)

            # print(block_h.mixer.in_proj.weight.shape)
            # print(block_t.self_attn.v_proj.weight.shape)
            block_h.mlp.load_state_dict(block_t.mlp.state_dict())
            block_h.input_layernorm.load_state_dict(block_t.input_layernorm.state_dict())
            block_h.post_attention_layernorm.load_state_dict(block_t.post_attention_layernorm.state_dict())
            block_h.mixer.out_proj.load_state_dict(block_t.self_attn.o_proj.state_dict())
            # [x B C z A_log]
            # print(block_h.mixer.d_inner)
            # init x, but interleave to address GQA
            v_expended = expand_k_q(block_t.self_attn.v_proj.weight.data)
            block_h.mixer.in_proj.weight.data[:block_h.mixer.d_inner, : ].copy_(v_expended)
            # init k, but interleave to address GQA
            k_expended = expand_k_q(block_t.self_attn.k_proj.weight.data)
            block_h.mixer.in_proj.weight.data[block_h.mixer.d_inner: 2*block_h.mixer.d_inner, : ].copy_(k_expended)
            # init C ewith Q
            block_h.mixer.in_proj.weight.data[2*block_h.mixer.d_inner: 3*block_h.mixer.d_inner, : ].copy_(block_t.self_attn.q_proj.weight.data)
        else:
            if not isinstance(block_h, AprielIdentityLayer):
                assert sum([p.sum() for p in block_h.state_dict().values()]) == sum([p.sum() for p in block_t.state_dict().values()])
            else:
                print("Identity layer")

mil_innit(hybrid_apriel_model, apriel_model)

Innitiating SSM layer
Innitiating SSM layer
Innitiating SSM layer
Innitiating SSM layer
Innitiating SSM layer
Innitiating SSM layer
Innitiating SSM layer
Innitiating SSM layer
Innitiating SSM layer
Innitiating SSM layer
Innitiating SSM layer
Innitiating SSM layer
Innitiating SSM layer
Innitiating SSM layer


In [25]:
hybrid_apriel_model.config


AprielSSMHybridConfig {
  "_name_or_path": "ServiceNow-AI/Apriel-5B-Instruct",
  "architectures": [
    "AprielSSMHybridForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "auto_map": {
    "AutoConfig": "ServiceNow-AI/Apriel-5B-Instruct--configuration_apriel.AprielConfig",
    "AutoModelForCausalLM": "ServiceNow-AI/Apriel-5B-Instruct--modeling_apriel.AprielForCausalLM"
  },
  "bos_token_id": 1,
  "eos_token_id": 2,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "hybrid_block_layout": [
    "t",
    "t",
    "m2d",
    "t",
    "m2d",
    "m2d",
    "m2d",
    "m2d",
    "m2d",
    "m2d",
    "t",
    "t",
    "t",
    "t",
    "t",
    "t",
    "t",
    "t",
    "t",
    "m2d",
    "m2d",
    "t",
    "m2d",
    "m2d",
    "m2d",
    "t",
    "m2d",
    "m2d"
  ],
  "initializer_range": 0.02,
  "intermediate_size": 8192,
  "max_position_embeddings": 16384,
  "mlp_bias": false,
  "model_type": "apriel_ssm_hybrid",
  "num_attention_heads": 24,

#### Save Hybrid checkpoint

In [24]:
# hybrid_apriel_model.save_pretrained("/mnt/checkpoints/ssm/apriel_ssm_instruct_hybrid_ssm2nd_init_mambainlama", save_config=True)
# hybrid_apriel_model.save_pretrained("/mnt/checkpoints/ssm/apriel_ssm_instruct_fulltransformer_init_mambainlama", save_config=True)
# hybrid_apriel_model.save_pretrained("/mnt/checkpoints/ssm/apriel_ssm_instruct_hybrid_test_init_mambainlama", save_config=True)
hybrid_apriel_model.save_pretrained("/mnt/checkpoints/ssm/apriel_ssm_instruct_hybrid_14ssm_leastimportant_init_mambainlama", save_config=True)



In [27]:
reloaded_model = AprielSSMHybridModel.from_pretrained("/mnt/checkpoints/ssm/apriel_ssm_instruct_hybrid_ssm2nd_init_mambainlama", torch_dtype=torch.bfloat16, trust_remote_code=True)

The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.89it/s]


In [28]:
reloaded_model

AprielSSMHybridModel(
  (embed_tokens): Embedding(131072, 4096)
  (layers): ModuleList(
    (0): AprielSSMDecoderLayer(
      (mixer): DiscreteMamba2(
        (in_proj): Linear(in_features=4096, out_features=9240, bias=False)
        (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)
        (act): Identity()
        (out_proj): Linear(in_features=3072, out_features=4096, bias=False)
      )
      (mlp): AprielMLP(
        (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)
        (up_proj): Linear(in_features=4096, out_features=8192, bias=False)
        (down_proj): Linear(in_features=8192, out_features=4096, bias=False)
        (act_fn): SiLU()
      )
      (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)
      (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)
    )
    (1): AprielDecoderLayer(
      (self_attn): AprielAttention(
        (q_proj): Linear(in_features=4096, out_features=3072, bias=False)
        (k_

### Mamba in LLama pure SSM

In [14]:
import torch
from mamba_ssm import MambaLMHeadModel
from mamba_ssm.models.config_mamba import MambaConfig
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from fast_llm.models.ssm.external.aperiel_ssm.configuration_ssm_apriel import AprielSSMConfig
from fast_llm.models.ssm.external.aperiel_ssm.modeling_ssm_apriel import AprielSSMForCausalLM
from fast_llm.models.ssm.external.aperiel_ssm.modeling_ssm_apriel import AprielDecoderLayer as AprielSSMDecoderLayer
from transformers.cache_utils import StaticCache
from types import SimpleNamespace

# make sure the code changes reflected without reload
%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
checkpoint = "ServiceNow-AI/Apriel-5B-Instruct"
device = "cuda"
# checkpoint = "/mnt/checkpoints/upstream/Apriel-5B-Instruct-llamafied"
config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
apriel_model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True)
apriel_state_dict = apriel_model.state_dict()
apriel_model.to(device).to(dtype=torch.bfloat16)
apriel_state_dict = apriel_model.state_dict()

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  9.03it/s]


In [1]:

apriel_ssm_config = AprielSSMConfig(vocab_size=config.vocab_size, 
                                    hidden_size=config.hidden_size,
                                    intermediate_size=config.intermediate_size,
                                    num_hidden_layers=config.num_hidden_layers,
                                    hidden_act=config.hidden_act,
                                    initializer_range=config.initializer_range,
                                    use_cache=config.use_cache,
                                    mlp_bias=config.mlp_bias,
                                    tie_word_embeddings=config.tie_word_embeddings,
                                    pad_token_id=config.pad_token_id,
                                    bos_token_id=config.bos_token_id,
                                    eos_token_id=config.eos_token_id,
                                    rms_norm_eps=config.rms_norm_eps,
                                    ssm_cfg={
                                        "d_state": 64,
                                        "n_v_heads": 24,
                                        "n_qk_heads": 24,
                                        "expand": 1,
                                        "chunk_size": 128,
                                        "activation": "identity",
                                        "bias": False,
                                        "d_inner": 24 * 128,
                                    })


NameError: name 'AprielSSMConfig' is not defined

In [20]:
apriel_ssm = AprielSSMForCausalLM(apriel_ssm_config)

In [21]:
apriel_ssm.to(device).to(dtype=torch.bfloat16)

AprielSSMForCausalLM(
  (model): AprielSSMModel(
    (embed_tokens): Embedding(131072, 4096)
    (layers): ModuleList(
      (0-27): 28 x AprielDecoderLayer(
        (mixer): DiscreteMamba2(
          (in_proj): Linear(in_features=4096, out_features=9240, bias=False)
          (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)
          (act): Identity()
          (out_proj): Linear(in_features=3072, out_features=4096, bias=False)
        )
        (mlp): AprielMLP(
          (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)
          (up_proj): Linear(in_features=4096, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): AprielRMSNorm((4096,), eps=1e-05)
  )
  (lm_head): Linear(in_fea

In [22]:

apriel_ssm.load_state_dict(apriel_state_dict, strict=False)

_IncompatibleKeys(missing_keys=['model.layers.0.mixer.z_bias', 'model.layers.0.mixer.D', 'model.layers.0.mixer.in_proj.weight', 'model.layers.0.mixer.conv1d.weight', 'model.layers.0.mixer.conv1d.bias', 'model.layers.0.mixer.out_proj.weight', 'model.layers.1.mixer.z_bias', 'model.layers.1.mixer.D', 'model.layers.1.mixer.in_proj.weight', 'model.layers.1.mixer.conv1d.weight', 'model.layers.1.mixer.conv1d.bias', 'model.layers.1.mixer.out_proj.weight', 'model.layers.2.mixer.z_bias', 'model.layers.2.mixer.D', 'model.layers.2.mixer.in_proj.weight', 'model.layers.2.mixer.conv1d.weight', 'model.layers.2.mixer.conv1d.bias', 'model.layers.2.mixer.out_proj.weight', 'model.layers.3.mixer.z_bias', 'model.layers.3.mixer.D', 'model.layers.3.mixer.in_proj.weight', 'model.layers.3.mixer.conv1d.weight', 'model.layers.3.mixer.conv1d.bias', 'model.layers.3.mixer.out_proj.weight', 'model.layers.4.mixer.z_bias', 'model.layers.4.mixer.D', 'model.layers.4.mixer.in_proj.weight', 'model.layers.4.mixer.conv1d.wei

In [23]:
mil_innit(apriel_ssm, apriel_model, AprielSSMDecoderLayer)

0 <class 'fast_llm.models.ssm.external.aperiel_ssm.modeling_ssm_apriel.AprielDecoderLayer'>
Innitiating SSM layer
1 <class 'fast_llm.models.ssm.external.aperiel_ssm.modeling_ssm_apriel.AprielDecoderLayer'>
Innitiating SSM layer
2 <class 'fast_llm.models.ssm.external.aperiel_ssm.modeling_ssm_apriel.AprielDecoderLayer'>
Innitiating SSM layer
3 <class 'fast_llm.models.ssm.external.aperiel_ssm.modeling_ssm_apriel.AprielDecoderLayer'>
Innitiating SSM layer
4 <class 'fast_llm.models.ssm.external.aperiel_ssm.modeling_ssm_apriel.AprielDecoderLayer'>
Innitiating SSM layer
5 <class 'fast_llm.models.ssm.external.aperiel_ssm.modeling_ssm_apriel.AprielDecoderLayer'>
Innitiating SSM layer
6 <class 'fast_llm.models.ssm.external.aperiel_ssm.modeling_ssm_apriel.AprielDecoderLayer'>
Innitiating SSM layer
7 <class 'fast_llm.models.ssm.external.aperiel_ssm.modeling_ssm_apriel.AprielDecoderLayer'>
Innitiating SSM layer
8 <class 'fast_llm.models.ssm.external.aperiel_ssm.modeling_ssm_apriel.AprielDecoderLaye

In [25]:
apriel_ssm.save_pretrained("/mnt/checkpoints/ssm/apriel_ssm_instruct_init_mambainlama", save_config=True)




In [124]:
device = "cpu" #if torch.cuda.is_available() else "cpu"
input_ids = torch.randint(0, 32000, (1, 128), dtype=torch.long, device=device)
batch_size = 1
max_length = 128
state = SimpleNamespace()
state.key_value_memory_dict = hybrid_apriel_model.allocate_inference_cache(batch_size, max_length, dtype=torch.bfloat16)
state.batch_size = batch_size
state.seqlen_offset = 0
static_inputs = {"inference_params": state,
        "input_ids": input_ids,
        "use_cache": True,
}


torch.Size([1024, 4096])

In [None]:
hybrid_apriel_model.to(device).to(dtype=torch.bfloat16)

In [None]:

hybrid_apriel_model.forward(**static_inputs)

In [5]:
d_xb = config.num_key_value_heads * config.head_dim
ssm_layers = [2,4,8]
attn_layers = [i for i in range(config.num_hidden_layers) if i not in ssm_layers]
model_name = "ServiceNow-AI/Apriel-5B-Instruct"
ngroups = config.num_attention_heads # n heads
d_inner = config.head_dim * config.num_attention_heads
headdim = 128 # d_state
d_state = config.head_dim
d_model = config.hidden_size    
assert d_inner == ngroups * d_state

mamba_config = AprielSSMConfig(
    ssm_cfg={
            "d_state": 64,
            "n_v_heads": 24,
            "n_qk_heads": 24,
            "expand": 1,
            "chunk_size": 128,
            "activation": "identity",
            "bias": False,
            "d_inner": 24 * headdim,  # num_heads * head_dim
        },
    vocab_size=config.vocab_size, 
    hidden_size=config.hidden_size,
    intermediate_size=config.intermediate_size,
    num_hidden_layers=config.num_hidden_layers,
    hidden_act=config.hidden_act,
    initializer_range=config.initializer_range,
    use_cache=config.use_cache,
    mlp_bias=config.mlp_bias,
    tie_word_embeddings=config.tie_word_embeddings,
    pad_token_id=config.pad_token_id,
    bos_token_id=config.bos_token_id,
    eos_token_id=config.eos_token_id,
    head_dim=config.head_dim,
    rms_norm_eps=config.rms_norm_eps
)

In [None]:
student_model = MambaTransformerHybridModelWrapper.init_distillation(None, model_name, 
                                                                     mamba_config, 
                                                                     attn_layers=attn_layers, 
                                                                     init_with_kqvo=True, 
                                                                     attn_implementation="flash_attention_2")


### Llamba

In [50]:
from cartesia_pytorch.Llamba.llamba import LlambaLMHeadModel, LlambaConfig

In [51]:
llamba = LlambaLMHeadModel.from_pretrained("/mnt/checkpoints_fml/pretrained_models/llamba-1b/mohawk_distributed_stage2_from_final/checkpoints/mohawk_step9000",
                                           use_safetensors=False,
                                           torch_dtype=torch.bfloat16,
                                           trust_remote_code=True,
                                           device="cuda")

llamba.to(device).to(dtype=torch.bfloat16)

RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx

In [52]:
from cartesia_pytorch.Llamba.llamba import LlambaLMHeadModel, LlambaConfig
import torch
import json

In [55]:
state_dict = torch.load("/mnt/checkpoints_fml/pretrained_models/llamba-1b/mohawk_distributed_stage2_from_final/checkpoints/mohawk_final/pytorch_model.bin",  weights_only=False)



RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

In [48]:
config = json.load(open("/mnt/checkpoints_fml/pretrained_models/llamba-1b/mohawk_distributed_stage2_from_final/checkpoints/mohawk_final/config.json"))
llamba = LlambaLMHeadModel(LlambaConfig(**config))

In [49]:
llamba.load_state_dict(state_dict, strict=True)



RuntimeError: Error(s) in loading state_dict for LlambaLMHeadModel:
	Missing key(s) in state_dict: "backbone.embedding.weight", "backbone.layers.0.mixer.z_bias", "backbone.layers.0.mixer.D", "backbone.layers.0.mixer.in_proj.weight", "backbone.layers.0.mixer.conv1d.weight", "backbone.layers.0.mixer.conv1d.bias", "backbone.layers.0.mixer.out_proj.weight", "backbone.layers.0.input_layernorm.weight", "backbone.layers.0.post_attention_layernorm.weight", "backbone.layers.0.mlp.gate_proj.weight", "backbone.layers.0.mlp.up_proj.weight", "backbone.layers.0.mlp.down_proj.weight", "backbone.layers.1.mixer.z_bias", "backbone.layers.1.mixer.D", "backbone.layers.1.mixer.in_proj.weight", "backbone.layers.1.mixer.conv1d.weight", "backbone.layers.1.mixer.conv1d.bias", "backbone.layers.1.mixer.out_proj.weight", "backbone.layers.1.input_layernorm.weight", "backbone.layers.1.post_attention_layernorm.weight", "backbone.layers.1.mlp.gate_proj.weight", "backbone.layers.1.mlp.up_proj.weight", "backbone.layers.1.mlp.down_proj.weight", "backbone.layers.2.mixer.z_bias", "backbone.layers.2.mixer.D", "backbone.layers.2.mixer.in_proj.weight", "backbone.layers.2.mixer.conv1d.weight", "backbone.layers.2.mixer.conv1d.bias", "backbone.layers.2.mixer.out_proj.weight", "backbone.layers.2.input_layernorm.weight", "backbone.layers.2.post_attention_layernorm.weight", "backbone.layers.2.mlp.gate_proj.weight", "backbone.layers.2.mlp.up_proj.weight", "backbone.layers.2.mlp.down_proj.weight", "backbone.layers.3.mixer.z_bias", "backbone.layers.3.mixer.D", "backbone.layers.3.mixer.in_proj.weight", "backbone.layers.3.mixer.conv1d.weight", "backbone.layers.3.mixer.conv1d.bias", "backbone.layers.3.mixer.out_proj.weight", "backbone.layers.3.input_layernorm.weight", "backbone.layers.3.post_attention_layernorm.weight", "backbone.layers.3.mlp.gate_proj.weight", "backbone.layers.3.mlp.up_proj.weight", "backbone.layers.3.mlp.down_proj.weight", "backbone.layers.4.mixer.z_bias", "backbone.layers.4.mixer.D", "backbone.layers.4.mixer.in_proj.weight", "backbone.layers.4.mixer.conv1d.weight", "backbone.layers.4.mixer.conv1d.bias", "backbone.layers.4.mixer.out_proj.weight", "backbone.layers.4.input_layernorm.weight", "backbone.layers.4.post_attention_layernorm.weight", "backbone.layers.4.mlp.gate_proj.weight", "backbone.layers.4.mlp.up_proj.weight", "backbone.layers.4.mlp.down_proj.weight", "backbone.layers.5.mixer.z_bias", "backbone.layers.5.mixer.D", "backbone.layers.5.mixer.in_proj.weight", "backbone.layers.5.mixer.conv1d.weight", "backbone.layers.5.mixer.conv1d.bias", "backbone.layers.5.mixer.out_proj.weight", "backbone.layers.5.input_layernorm.weight", "backbone.layers.5.post_attention_layernorm.weight", "backbone.layers.5.mlp.gate_proj.weight", "backbone.layers.5.mlp.up_proj.weight", "backbone.layers.5.mlp.down_proj.weight", "backbone.layers.6.mixer.z_bias", "backbone.layers.6.mixer.D", "backbone.layers.6.mixer.in_proj.weight", "backbone.layers.6.mixer.conv1d.weight", "backbone.layers.6.mixer.conv1d.bias", "backbone.layers.6.mixer.out_proj.weight", "backbone.layers.6.input_layernorm.weight", "backbone.layers.6.post_attention_layernorm.weight", "backbone.layers.6.mlp.gate_proj.weight", "backbone.layers.6.mlp.up_proj.weight", "backbone.layers.6.mlp.down_proj.weight", "backbone.layers.7.mixer.z_bias", "backbone.layers.7.mixer.D", "backbone.layers.7.mixer.in_proj.weight", "backbone.layers.7.mixer.conv1d.weight", "backbone.layers.7.mixer.conv1d.bias", "backbone.layers.7.mixer.out_proj.weight", "backbone.layers.7.input_layernorm.weight", "backbone.layers.7.post_attention_layernorm.weight", "backbone.layers.7.mlp.gate_proj.weight", "backbone.layers.7.mlp.up_proj.weight", "backbone.layers.7.mlp.down_proj.weight", "backbone.layers.8.mixer.z_bias", "backbone.layers.8.mixer.D", "backbone.layers.8.mixer.in_proj.weight", "backbone.layers.8.mixer.conv1d.weight", "backbone.layers.8.mixer.conv1d.bias", "backbone.layers.8.mixer.out_proj.weight", "backbone.layers.8.input_layernorm.weight", "backbone.layers.8.post_attention_layernorm.weight", "backbone.layers.8.mlp.gate_proj.weight", "backbone.layers.8.mlp.up_proj.weight", "backbone.layers.8.mlp.down_proj.weight", "backbone.layers.9.mixer.z_bias", "backbone.layers.9.mixer.D", "backbone.layers.9.mixer.in_proj.weight", "backbone.layers.9.mixer.conv1d.weight", "backbone.layers.9.mixer.conv1d.bias", "backbone.layers.9.mixer.out_proj.weight", "backbone.layers.9.input_layernorm.weight", "backbone.layers.9.post_attention_layernorm.weight", "backbone.layers.9.mlp.gate_proj.weight", "backbone.layers.9.mlp.up_proj.weight", "backbone.layers.9.mlp.down_proj.weight", "backbone.layers.10.mixer.z_bias", "backbone.layers.10.mixer.D", "backbone.layers.10.mixer.in_proj.weight", "backbone.layers.10.mixer.conv1d.weight", "backbone.layers.10.mixer.conv1d.bias", "backbone.layers.10.mixer.out_proj.weight", "backbone.layers.10.input_layernorm.weight", "backbone.layers.10.post_attention_layernorm.weight", "backbone.layers.10.mlp.gate_proj.weight", "backbone.layers.10.mlp.up_proj.weight", "backbone.layers.10.mlp.down_proj.weight", "backbone.layers.11.mixer.z_bias", "backbone.layers.11.mixer.D", "backbone.layers.11.mixer.in_proj.weight", "backbone.layers.11.mixer.conv1d.weight", "backbone.layers.11.mixer.conv1d.bias", "backbone.layers.11.mixer.out_proj.weight", "backbone.layers.11.input_layernorm.weight", "backbone.layers.11.post_attention_layernorm.weight", "backbone.layers.11.mlp.gate_proj.weight", "backbone.layers.11.mlp.up_proj.weight", "backbone.layers.11.mlp.down_proj.weight", "backbone.layers.12.mixer.z_bias", "backbone.layers.12.mixer.D", "backbone.layers.12.mixer.in_proj.weight", "backbone.layers.12.mixer.conv1d.weight", "backbone.layers.12.mixer.conv1d.bias", "backbone.layers.12.mixer.out_proj.weight", "backbone.layers.12.input_layernorm.weight", "backbone.layers.12.post_attention_layernorm.weight", "backbone.layers.12.mlp.gate_proj.weight", "backbone.layers.12.mlp.up_proj.weight", "backbone.layers.12.mlp.down_proj.weight", "backbone.layers.13.mixer.z_bias", "backbone.layers.13.mixer.D", "backbone.layers.13.mixer.in_proj.weight", "backbone.layers.13.mixer.conv1d.weight", "backbone.layers.13.mixer.conv1d.bias", "backbone.layers.13.mixer.out_proj.weight", "backbone.layers.13.input_layernorm.weight", "backbone.layers.13.post_attention_layernorm.weight", "backbone.layers.13.mlp.gate_proj.weight", "backbone.layers.13.mlp.up_proj.weight", "backbone.layers.13.mlp.down_proj.weight", "backbone.layers.14.mixer.z_bias", "backbone.layers.14.mixer.D", "backbone.layers.14.mixer.in_proj.weight", "backbone.layers.14.mixer.conv1d.weight", "backbone.layers.14.mixer.conv1d.bias", "backbone.layers.14.mixer.out_proj.weight", "backbone.layers.14.input_layernorm.weight", "backbone.layers.14.post_attention_layernorm.weight", "backbone.layers.14.mlp.gate_proj.weight", "backbone.layers.14.mlp.up_proj.weight", "backbone.layers.14.mlp.down_proj.weight", "backbone.layers.15.mixer.z_bias", "backbone.layers.15.mixer.D", "backbone.layers.15.mixer.in_proj.weight", "backbone.layers.15.mixer.conv1d.weight", "backbone.layers.15.mixer.conv1d.bias", "backbone.layers.15.mixer.out_proj.weight", "backbone.layers.15.input_layernorm.weight", "backbone.layers.15.post_attention_layernorm.weight", "backbone.layers.15.mlp.gate_proj.weight", "backbone.layers.15.mlp.up_proj.weight", "backbone.layers.15.mlp.down_proj.weight", "backbone.final_layernorm.weight". 
	Unexpected key(s) in state_dict: "model.embed_tokens.weight", "model.layers.0.self_attn.q_proj.weight", "model.layers.0.self_attn.k_proj.weight", "model.layers.0.self_attn.v_proj.weight", "model.layers.0.self_attn.o_proj.weight", "model.layers.0.mlp.gate_proj.weight", "model.layers.0.mlp.up_proj.weight", "model.layers.0.mlp.down_proj.weight", "model.layers.0.input_layernorm.weight", "model.layers.0.post_attention_layernorm.weight", "model.layers.1.self_attn.q_proj.weight", "model.layers.1.self_attn.k_proj.weight", "model.layers.1.self_attn.v_proj.weight", "model.layers.1.self_attn.o_proj.weight", "model.layers.1.mlp.gate_proj.weight", "model.layers.1.mlp.up_proj.weight", "model.layers.1.mlp.down_proj.weight", "model.layers.1.input_layernorm.weight", "model.layers.1.post_attention_layernorm.weight", "model.layers.2.self_attn.q_proj.weight", "model.layers.2.self_attn.k_proj.weight", "model.layers.2.self_attn.v_proj.weight", "model.layers.2.self_attn.o_proj.weight", "model.layers.2.mlp.gate_proj.weight", "model.layers.2.mlp.up_proj.weight", "model.layers.2.mlp.down_proj.weight", "model.layers.2.input_layernorm.weight", "model.layers.2.post_attention_layernorm.weight", "model.layers.3.self_attn.q_proj.weight", "model.layers.3.self_attn.k_proj.weight", "model.layers.3.self_attn.v_proj.weight", "model.layers.3.self_attn.o_proj.weight", "model.layers.3.mlp.gate_proj.weight", "model.layers.3.mlp.up_proj.weight", "model.layers.3.mlp.down_proj.weight", "model.layers.3.input_layernorm.weight", "model.layers.3.post_attention_layernorm.weight", "model.layers.4.self_attn.q_proj.weight", "model.layers.4.self_attn.k_proj.weight", "model.layers.4.self_attn.v_proj.weight", "model.layers.4.self_attn.o_proj.weight", "model.layers.4.mlp.gate_proj.weight", "model.layers.4.mlp.up_proj.weight", "model.layers.4.mlp.down_proj.weight", "model.layers.4.input_layernorm.weight", "model.layers.4.post_attention_layernorm.weight", "model.layers.5.self_attn.q_proj.weight", "model.layers.5.self_attn.k_proj.weight", "model.layers.5.self_attn.v_proj.weight", "model.layers.5.self_attn.o_proj.weight", "model.layers.5.mlp.gate_proj.weight", "model.layers.5.mlp.up_proj.weight", "model.layers.5.mlp.down_proj.weight", "model.layers.5.input_layernorm.weight", "model.layers.5.post_attention_layernorm.weight", "model.layers.6.self_attn.q_proj.weight", "model.layers.6.self_attn.k_proj.weight", "model.layers.6.self_attn.v_proj.weight", "model.layers.6.self_attn.o_proj.weight", "model.layers.6.mlp.gate_proj.weight", "model.layers.6.mlp.up_proj.weight", "model.layers.6.mlp.down_proj.weight", "model.layers.6.input_layernorm.weight", "model.layers.6.post_attention_layernorm.weight", "model.layers.7.self_attn.q_proj.weight", "model.layers.7.self_attn.k_proj.weight", "model.layers.7.self_attn.v_proj.weight", "model.layers.7.self_attn.o_proj.weight", "model.layers.7.mlp.gate_proj.weight", "model.layers.7.mlp.up_proj.weight", "model.layers.7.mlp.down_proj.weight", "model.layers.7.input_layernorm.weight", "model.layers.7.post_attention_layernorm.weight", "model.layers.8.self_attn.q_proj.weight", "model.layers.8.self_attn.k_proj.weight", "model.layers.8.self_attn.v_proj.weight", "model.layers.8.self_attn.o_proj.weight", "model.layers.8.mlp.gate_proj.weight", "model.layers.8.mlp.up_proj.weight", "model.layers.8.mlp.down_proj.weight", "model.layers.8.input_layernorm.weight", "model.layers.8.post_attention_layernorm.weight", "model.layers.9.self_attn.q_proj.weight", "model.layers.9.self_attn.k_proj.weight", "model.layers.9.self_attn.v_proj.weight", "model.layers.9.self_attn.o_proj.weight", "model.layers.9.mlp.gate_proj.weight", "model.layers.9.mlp.up_proj.weight", "model.layers.9.mlp.down_proj.weight", "model.layers.9.input_layernorm.weight", "model.layers.9.post_attention_layernorm.weight", "model.layers.10.self_attn.q_proj.weight", "model.layers.10.self_attn.k_proj.weight", "model.layers.10.self_attn.v_proj.weight", "model.layers.10.self_attn.o_proj.weight", "model.layers.10.mlp.gate_proj.weight", "model.layers.10.mlp.up_proj.weight", "model.layers.10.mlp.down_proj.weight", "model.layers.10.input_layernorm.weight", "model.layers.10.post_attention_layernorm.weight", "model.layers.11.self_attn.q_proj.weight", "model.layers.11.self_attn.k_proj.weight", "model.layers.11.self_attn.v_proj.weight", "model.layers.11.self_attn.o_proj.weight", "model.layers.11.mlp.gate_proj.weight", "model.layers.11.mlp.up_proj.weight", "model.layers.11.mlp.down_proj.weight", "model.layers.11.input_layernorm.weight", "model.layers.11.post_attention_layernorm.weight", "model.layers.12.self_attn.q_proj.weight", "model.layers.12.self_attn.k_proj.weight", "model.layers.12.self_attn.v_proj.weight", "model.layers.12.self_attn.o_proj.weight", "model.layers.12.mlp.gate_proj.weight", "model.layers.12.mlp.up_proj.weight", "model.layers.12.mlp.down_proj.weight", "model.layers.12.input_layernorm.weight", "model.layers.12.post_attention_layernorm.weight", "model.layers.13.self_attn.q_proj.weight", "model.layers.13.self_attn.k_proj.weight", "model.layers.13.self_attn.v_proj.weight", "model.layers.13.self_attn.o_proj.weight", "model.layers.13.mlp.gate_proj.weight", "model.layers.13.mlp.up_proj.weight", "model.layers.13.mlp.down_proj.weight", "model.layers.13.input_layernorm.weight", "model.layers.13.post_attention_layernorm.weight", "model.layers.14.self_attn.q_proj.weight", "model.layers.14.self_attn.k_proj.weight", "model.layers.14.self_attn.v_proj.weight", "model.layers.14.self_attn.o_proj.weight", "model.layers.14.mlp.gate_proj.weight", "model.layers.14.mlp.up_proj.weight", "model.layers.14.mlp.down_proj.weight", "model.layers.14.input_layernorm.weight", "model.layers.14.post_attention_layernorm.weight", "model.layers.15.self_attn.q_proj.weight", "model.layers.15.self_attn.k_proj.weight", "model.layers.15.self_attn.v_proj.weight", "model.layers.15.self_attn.o_proj.weight", "model.layers.15.mlp.gate_proj.weight", "model.layers.15.mlp.up_proj.weight", "model.layers.15.mlp.down_proj.weight", "model.layers.15.input_layernorm.weight", "model.layers.15.post_attention_layernorm.weight", "model.norm.weight". 

# Random Llamba

In [33]:
from transformers import LlamaConfig, LlamaForCausalLM


In [34]:
llama_config = LlamaConfig.from_pretrained("meta-llama/Llama-3.2-1B-Instruct", trust_remote_code=True)

The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.


In [42]:
llamba_config = LlambaConfig(**llama_config.to_dict(),
                             d_model=llama_config.hidden_size)
llamba = LlambaLMHeadModel(llamba_config)


In [43]:
llama_model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct", torch_dtype=torch.bfloat16, trust_remote_code=True)
state_dict = llama_model.state_dict()


The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.


In [44]:

llamba.load_state_dict(state_dict, strict=False)

_IncompatibleKeys(missing_keys=['backbone.embedding.weight', 'backbone.layers.0.mixer.z_bias', 'backbone.layers.0.mixer.D', 'backbone.layers.0.mixer.in_proj.weight', 'backbone.layers.0.mixer.conv1d.weight', 'backbone.layers.0.mixer.conv1d.bias', 'backbone.layers.0.mixer.out_proj.weight', 'backbone.layers.0.input_layernorm.weight', 'backbone.layers.0.post_attention_layernorm.weight', 'backbone.layers.0.mlp.gate_proj.weight', 'backbone.layers.0.mlp.up_proj.weight', 'backbone.layers.0.mlp.down_proj.weight', 'backbone.layers.1.mixer.z_bias', 'backbone.layers.1.mixer.D', 'backbone.layers.1.mixer.in_proj.weight', 'backbone.layers.1.mixer.conv1d.weight', 'backbone.layers.1.mixer.conv1d.bias', 'backbone.layers.1.mixer.out_proj.weight', 'backbone.layers.1.input_layernorm.weight', 'backbone.layers.1.post_attention_layernorm.weight', 'backbone.layers.1.mlp.gate_proj.weight', 'backbone.layers.1.mlp.up_proj.weight', 'backbone.layers.1.mlp.down_proj.weight', 'backbone.layers.2.mixer.z_bias', 'backbo

In [46]:

llamba.save_pretrained("/mnt/checkpoints/ssm/llamba1b_from_llama32instruct_ssminit_rand")