In [1]:
from transformers import Glm4MoeForCausalLM
from tqdm import tqdm
from torch import nn
import torch

# Load the base GLM-4.5-Air checkpoint from the local ramdisk copy.
# `dtype` replaces the deprecated `torch_dtype` keyword (the loader itself
# prints "`torch_dtype` is deprecated! Use `dtype` instead!").
# "auto" keeps the dtype stored in the checkpoint; device_map="auto" lets the
# loader place the weights across the available devices.
# NOTE: the loader reports the checkpoint's `model.layers.46.*` tensors as
# unused for Glm4MoeForCausalLM; the merge cells only touch layers
# 0..num_hidden_layers-1, so that warning does not affect them.
model = Glm4MoeForCausalLM.from_pretrained(
    'ramdisk/GLM-4.5-Air',
    dtype="auto",
    device_map="auto",
)

  from .autonotebook import tqdm as notebook_tqdm
  import pynvml  # type: ignore[import]
Skipping import of cpp extensions due to incompatible torch version 2.9.0+cu128 for torchao version 0.15.0             Please see https://github.com/pytorch/ao/issues/2919 for more info
`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 47/47 [00:37<00:00,  1.24it/s]
Some weights of the model checkpoint at ramdisk/GLM-4.5-Air were not used when initializing Glm4MoeForCausalLM: ['model.layers.46.eh_proj.weight', 'model.layers.46.embed_tokens.weight', 'model.layers.46.enorm.weight', 'model.layers.46.hnorm.weight', 'model.layers.46.input_layernorm.weight', 'model.layers.46.mlp.experts.0.down_proj.weight', 'model.layers.46.mlp.experts.0.gate_proj.weight', 'model.layers.46.mlp.experts.0.up_proj.

In [2]:
# Load the trained LoRA tensors onto the CPU; each merge step later moves the
# slices it needs to the device of the weight being patched.
# `weights_only=True` is spelled out so the file is deserialized as plain
# tensors/containers only (no arbitrary pickled objects), regardless of the
# default of the installed torch version.
mapping = torch.load(
    'nfs/nfs/GLM-4.5-Air-bf16/model_state_dict.pt',
    map_location='cpu',
    weights_only=True,
)
# Live view of the checkpoint's key names, used below to discover which
# modules carry `.lora_A` / `.lora_B` tensors.
keys = mapping.keys()

In [3]:
# Name -> tensor view of the loaded model. The merge cells below update these
# tensors in place (`W += ...`) and rely on the entries sharing storage with
# the live model parameters (see the `nn.Module.state_dict` documentation),
# so no `load_state_dict` call is made afterwards.
state_dict = model.state_dict()

In [4]:
# Merge the per-expert LoRA adapters into the MoE expert weights, in place.
#
# For every layer that has `<proj>_lora.A` / `<proj>_lora.B` entries in
# `mapping`, slice k along dim 0 of A and B is the adapter for expert k, and
#     experts[k].<proj>_proj.weight += (A[k] @ B[k] * LORA_SCALE).T
# Layers without an adapter entry are skipped.
#
# WARNING: the update is additive, so running this cell a second time on the
# same model adds the delta twice. Run it exactly once per freshly loaded model.

# Scale applied to every LoRA delta.
# NOTE(review): presumably the training-time alpha / rank factor — confirm
# against the training configuration.
LORA_SCALE = 2.0

# The three expert projections that may carry an adapter, processed in the
# same order for each layer.
EXPERT_PROJECTIONS = ('gate', 'up', 'down')

# A one-off weight merge does not need an autograd graph.
with torch.no_grad():
    for i in tqdm(range(model.config.num_hidden_layers)):
        for proj in EXPERT_PROJECTIONS:
            key_a = f'_orig_mod.model.layers.{i}.mlp.{proj}_lora.A'
            if key_a not in mapping:
                continue
            key_b = f'_orig_mod.model.layers.{i}.mlp.{proj}_lora.B'

            # Look the stacked adapter tensors up once per (layer, projection)
            # instead of once per expert.
            lora_a = mapping[key_a]
            lora_b = mapping[key_b]

            for k in range(lora_a.shape[0]):
                W = state_dict[f'model.layers.{i}.mlp.experts.{k}.{proj}_proj.weight']
                # Move only this expert's slices to the weight's device.
                a_k = lora_a[k].to(W.device)
                b_k = lora_b[k].to(W.device)
                delta = torch.matmul(a_k, b_k) * LORA_SCALE
                # The product is transposed before being added to the weight,
                # then cast to the weight's dtype.
                W += delta.T.to(W.dtype)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 46/46 [00:02<00:00, 19.05it/s]


In [5]:
# Merge the remaining LoRA adapters (keys of the form `<module>.lora_A` /
# `<module>.lora_B`) into their target `<module>.weight`, in place.
#
# WARNING: the update is additive; re-running this cell on the same model
# applies the delta a second time.

# Unique module prefixes that have a `.lora*` entry, in sorted order.
keys_lora = sorted({name.split('.lora')[0] for name in keys if '.lora' in name})

for prefix in tqdm(keys_lora):
    # Checkpoint keys carry an `_orig_mod.` prefix that the model's own
    # state-dict names do not have.
    target_name = prefix.replace('_orig_mod.', '') + '.weight'
    W = state_dict[target_name]

    lora_a = mapping[prefix + '.lora_A'].to(W.device)
    lora_b = mapping[prefix + '.lora_B'].to(W.device)

    # (A^T @ B^T) scaled by 2.0; its transpose (added below) equals
    # 2.0 * (B @ A).
    delta = torch.matmul(lora_a.t(), lora_b.t()) * 2.0
    W += delta.T.to(W.dtype)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 322/322 [00:00<00:00, 1341.64it/s]


In [10]:
from transformers import AutoTokenizer

# Tokenizer that ships with the base checkpoint.
tokenizer = AutoTokenizer.from_pretrained('ramdisk/GLM-4.5-Air')

# User turn: a Malay sentence followed by the instruction to translate it.
q = """
Budak itu sangat nakal, pantang orang leka sedikit, duit syiling pun dikebasnya.

terjemah ke kedah
"""

# System prompt template; `{{lang}}` is substituted with the reasoning language.
system = 'First, you try to think step-by-step in {{lang}}, after that, put your final answer within $\\boxed{}$.'
system_prompt = system.replace('{{lang}}', 'malay')

messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": q},
]

# Render the conversation to text first (ending with the assistant
# generation prompt), then tokenize that text separately.
row = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

# add_special_tokens=False: the rendered template text is tokenized as-is,
# without the tokenizer adding special tokens of its own.
encoded = tokenizer(row, add_special_tokens=False, return_tensors='pt').to(model.device)
input_ids = encoded['input_ids']
input_ids

tensor([[151331, 151333, 151335,    198,   5338,     11,    498,   1430,    311,
           1744,   3019,  14309,  29101,    304,   8640,    352,     11,   1283,
            429,     11,   2182,    697,   1590,   4226,   2878,  57564,  78439,
           6257,  12940, 151336,    271,     33,    661,    585,  35513,  69939,
          40659,    278,     11,  25201,    524,  49289,    512,   4554,  10918,
          59982,     11,    294,   3083,   6568,   7978,   8309,   1853,    440,
          17771,  22830,    382,    465,  62462,   1466,   1962,  78407,   1466,
            198, 151337]], device='cuda:0')

In [12]:
# Sampling configuration: temperature sampling only (top-p / top-k disabled).
# Generation is capped at 128 new tokens, so long answers are cut off.
# NOTE: do_sample=True with no fixed seed means each run produces a
# different completion.
gen_kwargs = {"max_new_tokens": 128, "do_sample": True, "temperature": 0.6, "top_p": None, "top_k": None}

# `input_ids` holds a single, unpadded prompt, so every position is a real
# token and the attention mask is all ones. Passing it explicitly avoids the
# "attention mask is not set and cannot be inferred" warning, which is raised
# because this tokenizer's pad token equals its eos token.
attention_mask = torch.ones_like(input_ids)

output_ids = model.generate(input_ids, attention_mask=attention_mask, **gen_kwargs)
# Decode prompt + completion, keeping special tokens so the chat markup is visible.
response = tokenizer.batch_decode(output_ids)[0]
response

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


'[gMASK]<sop><|system|>\nFirst, you try to think step-by-step in malay, after that, put your final answer within $\\boxed{}$.<|user|>\n\nBudak itu sangat nakal, pantang orang leka sedikit, duit syiling pun dikebasnya.\n\nterjemah ke kedah\n<|assistant|>\n<think>Baik, saya akan jelaskan dengan sangat terperinci dan langkah demi langkah bagaimana ayat Bahasa Melayu standard di atas ditukar kepada dialek Kedah. Saya akan membahagikan ayat kepada beberapa bahagian dan menerangkan setiap perubahan dari segi perkataan, struktur, bunyi, dan konteks budaya.\n\n## **Analisis Ayat Standard:**\n\n**Ayat Asal (Bahasa Melayu Standard):**\n\n> *Budak itu sangat nakal, pantang orang leka sedikit, duit syiling'