In [4]:


import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

# Absolute path to a local Fast-LLM checkout. This is machine-specific and must
# be adjusted when the notebook is run elsewhere.
fast_llm_path = "/home/toolkit/dev/Fast-LLM"

# Put the checkout on sys.path before the `apriel_hybrid` imports that follow.
# NOTE(review): presumably `apriel_hybrid` is resolved through this entry — confirm.
import sys
sys.path.append(fast_llm_path)
from apriel_hybrid.modeling_ssm_hybrid_apriel import AprielSSMHybridConfig
from apriel_hybrid.modeling_ssm_hybrid_apriel import AprielSSMHybridModel, AprielSSMDecoderLayer, AprielSSMHybridForCausalLM

%load_ext autoreload
%autoreload 2


In [5]:
# Baseline score that every per-layer score is compared against.
# NOTE(review): the metric/benchmark behind these numbers is not recorded in
# this notebook; presumably this is the unmodified model's score — confirm.
base = 0.612615

# (layer index as string, score recorded for that layer).
# Presumably each score was measured with only that layer altered — TODO confirm.
_recorded_scores = (
    ("22", 0.607389),
    ("24", 0.603498),
    ("19", 0.597907),
    ("27", 0.597173),
    ("20", 0.590442),
    ("5", 0.578949),
    ("4", 0.576852),
    ("9", 0.576484),
    ("23", 0.574833),
    ("7", 0.571860),
    ("8", 0.571790),
    ("6", 0.571614),
    ("2", 0.571330),
    ("26", 0.570205),
    ("11", 0.567128),
    ("14", 0.566175),
    ("15", 0.566076),
    ("3", 0.562861),
    ("1", 0.560154),
    ("13", 0.559304),
    ("16", 0.559017),
    ("10", 0.558789),
    ("12", 0.555186),
    ("17", 0.554236),
    ("25", 0.549215),
    ("18", 0.537257),
    ("0", 0.233085),
)

# After this line `layer_scores` maps each layer to its drop below the baseline
# (base - score), not to the raw score.
layer_scores = {layer: base - score for layer, score in _recorded_scores}

# List of (layer, drop) tuples ordered by ascending drop, i.e. the layer whose
# score is closest to the baseline comes first.
layer_importanfce = [
    (layer, layer_scores[layer])
    for layer in sorted(layer_scores, key=layer_scores.get)
]


In [6]:
# Show the ranking: (layer index, baseline - score) pairs, smallest drop first.
layer_importanfce

[('22', 0.005226000000000064),
 ('24', 0.009117000000000042),
 ('19', 0.014708000000000054),
 ('27', 0.015442000000000067),
 ('20', 0.022173),
 ('5', 0.033665999999999974),
 ('4', 0.03576299999999999),
 ('9', 0.036131000000000024),
 ('23', 0.03778199999999998),
 ('7', 0.040754999999999986),
 ('8', 0.040825),
 ('6', 0.041001000000000065),
 ('2', 0.041285000000000016),
 ('26', 0.04241000000000006),
 ('11', 0.045487000000000055),
 ('14', 0.04644000000000004),
 ('15', 0.046539),
 ('3', 0.049754000000000076),
 ('1', 0.05246099999999998),
 ('13', 0.053311),
 ('16', 0.053598000000000035),
 ('10', 0.05382600000000004),
 ('12', 0.05742900000000006),
 ('17', 0.05837900000000007),
 ('25', 0.06340000000000001),
 ('18', 0.07535800000000004),
 ('0', 0.37953000000000003)]

## Create hybrid with any number of SSM layers

In [7]:
# Hub checkpoint whose configuration is the starting point for the hybrid config.
checkpoint = "ServiceNow-AI/Apriel-5B-Instruct"
config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
device = "cuda"
# Number of entries taken from the front of the importance ranking.
n_hybrid = 1

# Integer layer indices of the first `n_hybrid` ranked layers.
index_swaped = [int(layer_importanfce[rank][0]) for rank in range(n_hybrid)]

# One code per decoder layer: every layer starts as "t" and the selected ones
# are switched to "m2d".
hybrid_block_layout = ["t"] * config.num_hidden_layers
for layer_idx in index_swaped:
    hybrid_block_layout[layer_idx] = "m2d"

# Settings forwarded unchanged as `ssm_cfg` to the hybrid config.
mixer_settings = {
    "d_state": 64,
    "n_v_heads": 24,
    "n_qk_heads": 24,
    "expand": 1,
    "chunk_size": 128,
    "activation": "identity",
    "bias": False,
    "d_inner": 24 * 128,  # num_heads * head_dim
}

# Every field of the reference config is carried over; only the layout and the
# SSM settings are added on top.
hybrdif_apriel_config = AprielSSMHybridConfig(
    **config.to_dict(),
    hybrid_block_layout=hybrid_block_layout,
    ssm_cfg=mixer_settings,
)

A new version of the following files was downloaded from https://huggingface.co/ServiceNow-AI/Apriel-5B-Instruct:
- configuration_apriel.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


In [8]:
# Show the per-layer block codes stored on the hybrid config ("t" / "m2d").
hybrdif_apriel_config.hybrid_block_layout

['t',
 't',
 't',
 't',
 't',
 't',
 't',
 't',
 't',
 't',
 't',
 't',
 't',
 't',
 't',
 't',
 't',
 't',
 't',
 't',
 't',
 't',
 'm2d',
 't',
 't',
 't',
 't',
 't']

In [9]:
# Instantiate the hybrid model from the config only; weights from the reference
# checkpoint are copied in by a later cell via load_state_dict.
hybrid_apriel_model = AprielSSMHybridForCausalLM(hybrdif_apriel_config)
# Cast the model to bfloat16, the same dtype the reference model is loaded with
# further down. This call is the cell's last expression, so its return value
# (the module) is what the notebook displays.
hybrid_apriel_model.to(dtype=torch.bfloat16)

AprielSSMHybridForCausalLM(
  (model): AprielSSMHybridModel(
    (embed_tokens): Embedding(131072, 4096)
    (layers): ModuleList(
      (0-21): 22 x AprielDecoderLayer(
        (self_attn): AprielAttention(
          (q_proj): Linear(in_features=4096, out_features=3072, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=3072, out_features=4096, bias=False)
        )
        (mlp): AprielMLP(
          (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)
          (up_proj): Linear(in_features=4096, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)
      )
      (22): AprielSSMDecoderLayer(
      

In [10]:

# Reference configuration for the same checkpoint (rebinds `config`).
config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)

# Reference model, loaded in bfloat16.
apriel_model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)

# Parameter-name -> tensor mapping that is copied into the hybrid model later.
apriel_state_dict = apriel_model.state_dict()

A new version of the following files was downloaded from https://huggingface.co/ServiceNow-AI/Apriel-5B-Instruct:
- modeling_apriel.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
Fetching 2 files: 100%|██████████| 2/2 [00:00<00:00, 12.97it/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  3.79it/s]


In [11]:
# Copy every matching tensor from the reference checkpoint into the hybrid
# model. strict=False lets the load continue when the two key sets differ; the
# differences are returned as (missing, unexpected) instead of raising.
missing, unexpected = hybrid_apriel_model.load_state_dict(apriel_state_dict, strict=False)

In [12]:
# Report how the reference checkpoint lined up with the hybrid model.
load_report = (
    # `missing`: parameters of the hybrid model that the checkpoint did not
    # provide (the newly added SSM layers).
    ("Missing keys:", missing),
    # `unexpected`: checkpoint entries with no counterpart in the hybrid model
    # (the transformer parts that were replaced).
    ("Unexpected keys:", unexpected),
)
for heading, key_names in load_report:
    print(heading, key_names)



Missing keys: ['model.layers.22.mixer.z_bias', 'model.layers.22.mixer.D', 'model.layers.22.mixer.in_proj.weight', 'model.layers.22.mixer.conv1d.weight', 'model.layers.22.mixer.conv1d.bias', 'model.layers.22.mixer.out_proj.weight']
Unexpected keys: ['model.layers.22.self_attn.q_proj.weight', 'model.layers.22.self_attn.k_proj.weight', 'model.layers.22.self_attn.v_proj.weight', 'model.layers.22.self_attn.o_proj.weight']


In [13]:
# Persist the hybrid model under a directory named after the swapped layer.
output_path = "/mnt/checkpoints/ssm/iterative_hybrids_5b"

# The directory name encodes a single layer, so exactly one swap is required.
assert len(index_swaped) == 1
layer_swaped = index_swaped[0]

# The number in the name is the swapped layer's 0-based index plus one.
# Built once so the save target and the printed message cannot drift apart.
save_dir = f"{output_path}/apriel_ssm_instruct5b_hybrid_{layer_swaped+1}ssm_leastimportant_32h_init_rand"

hybrid_apriel_model.save_pretrained(save_dir)
print(f"Hybrid model saved to {save_dir}")


Hybrid model saved to /mnt/checkpoints/ssm/iterative_hybrids_5b/apriel_ssm_instruct5b_hybrid_23ssm_leastimportant_32h_init_rand
