In [2]:
import torch
from mamba_ssm import MambaLMHeadModel
from mamba_ssm.models.config_mamba import MambaConfig
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from fast_llm.models.ssm.external.configuration_ssm_apriel import AprielSSMConfig
from fast_llm.models.ssm.external.modeling_ssm_apriel import AprielSSMForCausalLM
from transformers.cache_utils import StaticCache
from types import SimpleNamespace

# make sure the code changes reflected without reload
%load_ext autoreload
%autoreload 2


  from .autonotebook import tqdm as notebook_tqdm


# Apriel SSM for distillation

In [3]:

device = "cuda" if torch.cuda.is_available() else "cpu"


In [4]:
checkpoint = "ServiceNow-AI/Apriel-5B-Instruct"
config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
apriel_model     = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True)
apriel_state_dict = apriel_model.state_dict()
apriel_model.to(device).to(dtype=torch.bfloat16)

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  8.90it/s]


AprielForCausalLM(
  (model): AprielModel(
    (embed_tokens): Embedding(131072, 4096)
    (layers): ModuleList(
      (0-27): 28 x AprielDecoderLayer(
        (self_attn): AprielAttention(
          (q_proj): Linear(in_features=4096, out_features=3072, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=3072, out_features=4096, bias=False)
        )
        (mlp): AprielMLP(
          (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)
          (up_proj): Linear(in_features=4096, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): AprielRMSNorm((4096,), eps=1e-05)
    (ro

In [5]:
apriel_model.config.torch_dtype

torch.bfloat16

In [8]:
config_apriel = AprielSSMConfig.from_pretrained("/mnt/checkpoints_fml/pretrained_models/ssm/apriel_ssm_instruct_base", trust_remote_code=True)

The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.


In [10]:
apriel_ssm = AprielSSMForCausalLM(apriel_ssm_config)

In [12]:
apriel_ssm.state_dict()

OrderedDict([('model.embed_tokens.weight',
              tensor([[ 0.0105,  0.0330, -0.0032,  ...,  0.0076, -0.0051,  0.0112],
                      [-0.0111, -0.0101,  0.0064,  ...,  0.0144,  0.0098, -0.0194],
                      [ 0.0301,  0.0228,  0.0105,  ..., -0.0159,  0.0112, -0.0009],
                      ...,
                      [ 0.0266,  0.0224, -0.0150,  ...,  0.0189, -0.0253, -0.0300],
                      [-0.0304,  0.0249,  0.0140,  ..., -0.0235,  0.0315, -0.0188],
                      [-0.0215, -0.0034,  0.0035,  ..., -0.0125,  0.0084,  0.0246]])),
             ('model.layers.0.mixer.z_bias',
              tensor([0., 0., 0.,  ..., 0., 0., 0.])),
             ('model.layers.0.mixer.D',
              tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                      1., 1., 1., 1., 1., 1.])),
             ('model.layers.0.mixer.in_proj.weight',
              tensor([[ 0.0104,  0.0055, -0.0148,  ...,  0.0208, -0.0074,  0.0015],
   

In [14]:
print("N params SSM:", sum(p.numel() for p in apriel_ssm.parameters() if p.requires_grad)/1e9)


N params SSM: 5.305533088


# Load State dict into SSM

In [15]:

apriel_ssm.to(device).to(dtype=torch.bfloat16)

AprielSSMForCausalLM(
  (model): AprielSSMModel(
    (embed_tokens): Embedding(131072, 4096)
    (layers): ModuleList(
      (0-27): 28 x AprielDecoderLayer(
        (mixer): DiscreteMamba2(
          (in_proj): Linear(in_features=4096, out_features=9240, bias=False)
          (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)
          (act): Identity()
          (out_proj): Linear(in_features=3072, out_features=4096, bias=False)
        )
        (mlp): AprielMLP(
          (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)
          (up_proj): Linear(in_features=4096, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): AprielRMSNorm((4096,), eps=1e-05)
  )
  (lm_head): Linear(in_fea

In [16]:
apriel_ssm.load_state_dict(apriel_state_dict, strict=False)

_IncompatibleKeys(missing_keys=['model.layers.0.mixer.z_bias', 'model.layers.0.mixer.D', 'model.layers.0.mixer.in_proj.weight', 'model.layers.0.mixer.conv1d.weight', 'model.layers.0.mixer.conv1d.bias', 'model.layers.0.mixer.out_proj.weight', 'model.layers.1.mixer.z_bias', 'model.layers.1.mixer.D', 'model.layers.1.mixer.in_proj.weight', 'model.layers.1.mixer.conv1d.weight', 'model.layers.1.mixer.conv1d.bias', 'model.layers.1.mixer.out_proj.weight', 'model.layers.2.mixer.z_bias', 'model.layers.2.mixer.D', 'model.layers.2.mixer.in_proj.weight', 'model.layers.2.mixer.conv1d.weight', 'model.layers.2.mixer.conv1d.bias', 'model.layers.2.mixer.out_proj.weight', 'model.layers.3.mixer.z_bias', 'model.layers.3.mixer.D', 'model.layers.3.mixer.in_proj.weight', 'model.layers.3.mixer.conv1d.weight', 'model.layers.3.mixer.conv1d.bias', 'model.layers.3.mixer.out_proj.weight', 'model.layers.4.mixer.z_bias', 'model.layers.4.mixer.D', 'model.layers.4.mixer.in_proj.weight', 'model.layers.4.mixer.conv1d.wei

In [17]:

apriel_ssm.to(device).to(dtype=torch.bfloat16)

AprielSSMForCausalLM(
  (model): AprielSSMModel(
    (embed_tokens): Embedding(131072, 4096)
    (layers): ModuleList(
      (0-27): 28 x AprielDecoderLayer(
        (mixer): DiscreteMamba2(
          (in_proj): Linear(in_features=4096, out_features=9240, bias=False)
          (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)
          (act): Identity()
          (out_proj): Linear(in_features=3072, out_features=4096, bias=False)
        )
        (mlp): AprielMLP(
          (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)
          (up_proj): Linear(in_features=4096, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): AprielRMSNorm((4096,), eps=1e-05)
  )
  (lm_head): Linear(in_fea

### Save checkpoint

In [2]:
apriel_ssm.save_pretrained("/mnt/checkpoints/ssm/apriel_ssm_instruct_base",
                            save_config=True)


NameError: name 'apriel_ssm' is not defined

In [19]:
apriel_ssm.model.layers[0].mixer.n_v_heads

24

In [20]:
apriel_ssm

AprielSSMForCausalLM(
  (model): AprielSSMModel(
    (embed_tokens): Embedding(131072, 4096)
    (layers): ModuleList(
      (0-27): 28 x AprielDecoderLayer(
        (mixer): DiscreteMamba2(
          (in_proj): Linear(in_features=4096, out_features=9240, bias=False)
          (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)
          (act): Identity()
          (out_proj): Linear(in_features=3072, out_features=4096, bias=False)
        )
        (mlp): AprielMLP(
          (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)
          (up_proj): Linear(in_features=4096, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): AprielRMSNorm((4096,), eps=1e-05)
  )
  (lm_head): Linear(in_fea

# Try a forward pass

In [23]:
input_ids = torch.randint(0, 32000, (1, 128), dtype=torch.long, device=device)
batch_size = 1
max_length = 128
state = SimpleNamespace()
state.key_value_memory_dict = apriel_ssm.allocate_inference_cache(batch_size, max_length, dtype=torch.bfloat16)
state.batch_size = batch_size
state.seqlen_offset = 0
static_inputs = {"inference_params": state,
        "input_ids": input_ids,
        "use_cache": True,
}

In [24]:
apriel_ssm.forward(**static_inputs)

CustomMambaCausalLMOutput(loss=None, logits=tensor([[[-3.0781,  2.3594,  1.4609,  ..., -2.3438, -1.9688,  0.6484],
         [-5.8125,  4.9688,  0.4414,  ..., -4.2500, -3.5156, -4.8125],
         [-5.5000,  3.3594,  1.1484,  ..., -3.4375, -2.3125, -4.4375],
         ...,
         [-2.2812,  0.1465,  2.2344,  ..., -7.6875, -3.0312, -6.2500],
         [-6.8750,  1.7812, -1.3750,  ..., -7.4688, -5.6875, -4.4062],
         [-2.0156,  2.0938,  3.1094,  ..., -3.0156, -2.1406, -2.2812]]],
       device='cuda:0', grad_fn=<ToCopyBackward0>), all_hidden_states=(), last_hidden_state=tensor([[[-1.3828,  0.0625, -2.7500,  ..., -0.6523, -0.8906,  1.4609],
         [ 2.1406, -0.0247, -3.0156,  ..., -0.0074,  1.0234,  1.3828],
         [ 1.6016, -0.7266, -1.2422,  ..., -0.4004, -0.8242, -0.5586],
         ...,
         [ 1.5234, -0.0262, -1.5469,  ..., -0.4922, -1.0078,  1.2344],
         [-0.4629, -0.6055, -1.3906,  ..., -0.9922, -0.3066,  1.1875],
         [-0.7539, -0.0243, -2.4688,  ..., -1.0625, -

## Load Apriel SSM into HF class

In [130]:
import torch
from mamba_ssm import MambaLMHeadModel
from mamba_ssm.models.config_mamba import MambaConfig
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from fast_llm.models.ssm.external.configuration_ssm_apriel import AprielSSMConfig
from fast_llm.models.ssm.external.modeling_ssm_apriel import AprielSSMForCausalLM
from transformers.cache_utils import StaticCache
from types import SimpleNamespace
import os
import shutil
# make sure the code changes reflected without reload
%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
model_path = "/mnt/checkpoints/fast_llm_exp/slam_ssm_distill/apriel_ssminstr-distil-randinit-bs768-lr0.0003-sl4096_ti5000_luke_mix1/export/apriel_ssm/5000"
modeling_path = "/home/toolkit/dev/Fast-LLM/fast_llm/models/ssm/external"
# # copy the config.json to the model path
shutil.copy(os.path.join(modeling_path, "modeling_ssm_apriel.py"), os.path.join(model_path, "modeling_ssm_apriel.py"))
shutil.copy(os.path.join(modeling_path, "configuration_ssm_apriel.py"), os.path.join(model_path, "configuration_ssm_apriel.py"))

tokenizer_path = "/mnt/checkpoints/upstream/Mistral-Nemo-Base-2407/"
# # cp tokenizer*
# shutil.copy(os.path.join(tokenizer_path, "tokenizer.json"), os.path.join(model_path, "tokenizer.json"))
# shutil.copy(os.path.join(tokenizer_path, "tokenizer_config.json"), os.path.join(model_path, "tokenizer_config.json"))
# shutil.copy(os.path.join(tokenizer_path, "special_tokens_map.json"), os.path.join(model_path, "special_tokens_map.json"))
# shutil.copy(os.path.join(tokenizer_path, "vocab.json"), os.path.join(model_path, "vocab.json"))


In [4]:

apriel_ssm = AprielSSMForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, trust_remote_code=True, device="cuda")


The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.08s/it]


In [5]:
apriel_ssm

AprielSSMForCausalLM(
  (model): AprielSSMModel(
    (embed_tokens): Embedding(131072, 4096)
    (layers): ModuleList(
      (0-27): 28 x AprielDecoderLayer(
        (mixer): DiscreteMamba2(
          (in_proj): Linear(in_features=4096, out_features=9240, bias=False)
          (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)
          (act): Identity()
          (out_proj): Linear(in_features=3072, out_features=4096, bias=False)
        )
        (mlp): AprielMLP(
          (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)
          (up_proj): Linear(in_features=4096, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): AprielRMSNorm((4096,), eps=1e-05)
  )
  (lm_head): Linear(in_fea

In [6]:
config = apriel_ssm.config

# Mamba in Llama: SSM hybrid 

In [90]:

from fast_llm.models.ssm.external.configuration_ssm_apriel import AprielSSMConfig
import torch
from mamba_ssm import MambaLMHeadModel
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from fast_llm.models.ssm.external.configuration_ssm_apriel import AprielSSMConfig
from fast_llm.models.ssm.external.modeling_ssm_apriel import AprielSSMForCausalLM
from transformers.cache_utils import StaticCache
from types import SimpleNamespace
from fast_llm.models.ssm.external.modeling_ssm_hybrid_apriel import AprielSSMHybridConfig
from fast_llm.models.ssm.external.modeling_ssm_hybrid_apriel import AprielSSMHybridModel, AprielSSMDecoderLayer
# from fast_llm.models.ssm.external.__hybrid_wrapper import MambaTransformerHybridModelWrapper
# make sure the code changes reflected without reload
%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [81]:

checkpoint = "ServiceNow-AI/Apriel-5B-Instruct"
config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)

# d_xb = config.num_key_value_heads * config.head_dim
d_inner = config.num_attention_heads * config.head_dim
d_state = config.head_dim
hybrdif_apriel_config = AprielSSMHybridConfig(**config.to_dict(),
                                              ssm_block_pattern=["m2d", "t"] * 14,
                                              ssm_cfg={
                                                  "d_state": 64,
                                                  "n_v_heads": 24,
                                                  "n_qk_heads": 24,
                                                #   "d_xb": d_xb,
                                                  "expand": 1,
                                                  "chunk_size": 128,
                                                  "activation": "identity",
                                                  "bias": False,
                                                  "d_inner": 24 * 128,  # num_heads * head_dim
                                              })
# hybrdif_apriel_config

In [87]:
hybrid_apriel_model = AprielSSMHybridModel(hybrdif_apriel_config)

In [88]:
hybrid_apriel_model.layers[0]

AprielSSMDecoderLayer(
  (mixer): DiscreteMamba2(
    (in_proj): Linear(in_features=4096, out_features=9240, bias=False)
    (conv1d): Conv1d(6144, 6144, kernel_size=(4,), stride=(1,), padding=(3,), groups=6144)
    (act): Identity()
    (out_proj): Linear(in_features=3072, out_features=4096, bias=False)
  )
  (mlp): AprielMLP(
    (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)
    (up_proj): Linear(in_features=4096, out_features=8192, bias=False)
    (down_proj): Linear(in_features=8192, out_features=4096, bias=False)
    (act_fn): SiLU()
  )
  (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)
  (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)
)

In [91]:
isinstance(hybrid_apriel_model.layers[0], AprielSSMDecoderLayer)

True

In [84]:
device = "cpu" #if torch.cuda.is_available() else "cpu"
input_ids = torch.randint(0, 32000, (1, 128), dtype=torch.long, device=device)
batch_size = 1
max_length = 128
state = SimpleNamespace()
state.key_value_memory_dict = hybrid_apriel_model.allocate_inference_cache(batch_size, max_length, dtype=torch.bfloat16)
state.batch_size = batch_size
state.seqlen_offset = 0
static_inputs = {"inference_params": state,
        "input_ids": input_ids,
        "use_cache": True,
}


In [73]:
hybrid_apriel_model.to(device).to(dtype=torch.bfloat16)

OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 GiB. GPU 0 has a total capacity of 79.10 GiB of which 1.72 GiB is free. Process 191417 has 19.83 GiB memory in use. Process 1524280 has 57.54 GiB memory in use. Of the allocated memory 18.11 GiB is allocated by PyTorch, and 1.05 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [79]:

hybrid_apriel_model.forward(**static_inputs)

RuntimeError: split_with_sizes expects split_sizes to sum exactly to 8216 (input tensor's size at dimension -1), but got split_sizes=[6144, 3072, 24]

In [9]:
checkpoint = "ServiceNow-AI/Apriel-5B-Instruct"
config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
apriel_model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True)
apriel_state_dict = apriel_model.state_dict()
apriel_model.to(device).to(dtype=torch.bfloat16)

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.44it/s]


AprielForCausalLM(
  (model): AprielModel(
    (embed_tokens): Embedding(131072, 4096)
    (layers): ModuleList(
      (0-27): 28 x AprielDecoderLayer(
        (self_attn): AprielAttention(
          (q_proj): Linear(in_features=4096, out_features=3072, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=3072, out_features=4096, bias=False)
        )
        (mlp): AprielMLP(
          (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)
          (up_proj): Linear(in_features=4096, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): AprielRMSNorm((4096,), eps=1e-05)
    (ro

In [129]:
# Innitialization using k, q, v from Apriel transformer
def expand_k_q(k):
    Hq = config.num_attention_heads
    Hk = config.num_key_value_heads
    d_head = config.head_dim
    d = k.shape[-1]
    
    # Expand k
    repeat_factor = Hq // Hk
    k_expanded = k.view(Hk, d_head, d)
    k_expanded = k_expanded.repeat_interleave(repeat_factor, dim=0)
    k_expanded = k_expanded.view(d_head * Hq, d)
    return k_expanded

for block_h, block_t in zip(hybrid_apriel_model.layers, apriel_model.model.layers):
    # print(isinstance(block_h, AprielSSMDecoderLayer))
    if isinstance(block_h, AprielSSMDecoderLayer):
        # print(block_h.mixer.n_v_heads)
        # print(block_t.self_attn.v_proj.weight.shape)
        # print(block_h.mixer.in_proj.weight.shape)

        # print(block_h.mixer.in_proj.weight.shape)
        # print(block_t.self_attn.v_proj.weight.shape)
        block_h.mlp.load_state_dict(block_t.mlp.state_dict())
        block_h.input_layernorm.load_state_dict(block_t.input_layernorm.state_dict())
        block_h.post_attention_layernorm.load_state_dict(block_t.post_attention_layernorm.state_dict())
        block_h.mixer.out_proj.load_state_dict(block_t.self_attn.o_proj.state_dict())
        # [x B C z A_log]
        # print(block_h.mixer.d_inner)
        # init x, but interleave to address GQA
        v_expended = expand_k_q(block_t.self_attn.v_proj.weight.data)
        block_h.mixer.in_proj.weight.data[:block_h.mixer.d_inner, : ].copy_(v_expended)
        # init k, but interleave to address GQA
        k_expended = expand_k_q(block_t.self_attn.k_proj.weight.data)
        block_h.mixer.in_proj.weight.data[block_h.mixer.d_inner: 2*block_h.mixer.d_inner, : ].copy_(k_expended)
        # init C ewith Q
        block_h.mixer.in_proj.weight.data[2*block_h.mixer.d_inner: 3*block_h.mixer.d_inner, : ].copy_(block_t.self_attn.q_proj.weight.data)


In [124]:
block_t.self_attn.v_proj.weight.data.shape

torch.Size([1024, 4096])

In [None]:
#

In [5]:
d_xb = config.num_key_value_heads * config.head_dim
ssm_layers = [2,4,8]
attn_layers = [i for i in range(config.num_hidden_layers) if i not in ssm_layers]
model_name = "ServiceNow-AI/Apriel-5B-Instruct"
ngroups = config.num_attention_heads # n heads
d_inner = config.head_dim * config.num_attention_heads
headdim = 128 # d_state
d_state = config.head_dim
d_model = config.hidden_size    
assert d_inner == ngroups * d_state

mamba_config = AprielSSMConfig(
    ssm_cfg={
            "d_state": 64,
            "n_v_heads": 24,
            "n_qk_heads": 24,
            "expand": 1,
            "chunk_size": 128,
            "activation": "identity",
            "bias": False,
            "d_inner": 24 * headdim,  # num_heads * head_dim
        },
    vocab_size=config.vocab_size, 
    hidden_size=config.hidden_size,
    intermediate_size=config.intermediate_size,
    num_hidden_layers=config.num_hidden_layers,
    hidden_act=config.hidden_act,
    initializer_range=config.initializer_range,
    use_cache=config.use_cache,
    mlp_bias=config.mlp_bias,
    tie_word_embeddings=config.tie_word_embeddings,
    pad_token_id=config.pad_token_id,
    bos_token_id=config.bos_token_id,
    eos_token_id=config.eos_token_id,
    head_dim=config.head_dim,
    rms_norm_eps=config.rms_norm_eps
)

In [None]:
student_model = MambaTransformerHybridModelWrapper.init_distillation(None, model_name, 
                                                                     mamba_config, 
                                                                     attn_layers=attn_layers, 
                                                                     init_with_kqvo=True, 
                                                                     attn_implementation="flash_attention_2")
