In [1]:
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [2]:
from safetensors import safe_open
from transformers.trainer_utils import get_last_checkpoint
from glob import glob
from transformers import GptOssForCausalLM, AutoModelForCausalLM, AutoTokenizer, Mxfp4Config
from peft import PeftModel
import torch
torch.set_grad_enabled(False)

  from .autonotebook import tqdm as notebook_tqdm


torch.autograd.grad_mode.set_grad_enabled(mode=False)

In [3]:
tensors = {}
f = os.path.join(get_last_checkpoint('malaysian-reasoning-20b-lora-r64-experts'), 'weight.pt')
with safe_open(f, framework="pt", device='cpu') as f:
    for k in f.keys():
        tensors[k] = f.get_tensor(k)

In [4]:
tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")

In [5]:
model_kwargs = dict(
    attn_implementation="kernels-community/vllm-flash-attn3",
    torch_dtype=torch.bfloat16, 
    use_cache=True, 
)
model = AutoModelForCausalLM.from_pretrained("unsloth/gpt-oss-20b-BF16", **model_kwargs).cuda()

`torch_dtype` is deprecated! Use `dtype` instead!
Fetching 7 files: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 87122.04it/s]
Fetching 7 files: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 103017.99it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:00<00:00, 332.37it/s]


In [7]:
state_dict = model.state_dict()

In [8]:
state_dict[f'model.layers.0.mlp.experts.gate_up_proj'][0]

tensor([[ 0.0000,  0.0156,  0.0234,  ...,  0.0625,  0.0156, -0.0078],
        [ 0.0000, -0.0000,  0.0156,  ..., -0.0625,  0.0078, -0.0000],
        [ 0.0000, -0.0000,  0.0234,  ...,  0.0078,  0.0312,  0.0312],
        ...,
        [ 0.0312, -0.0625, -0.0000,  ...,  0.0312,  0.0625,  0.0469],
        [-0.0000, -0.0938, -0.0000,  ...,  0.0625,  0.0312,  0.0469],
        [ 0.0625, -0.0625, -0.0156,  ..., -0.0469, -0.0312, -0.0938]],
       device='cuda:0', dtype=torch.bfloat16)

In [9]:
total_rank = 64
top_k = model.config.num_experts_per_tok
r = total_rank // top_k
alpha = (total_rank * 2) // top_k
merge_scale = alpha / r

for i in range(model.config.num_hidden_layers):
    W = state_dict[f'model.layers.{i}.mlp.experts.gate_up_proj']
    A = tensors[f'model.layers.{i}.mlp.experts.lora_gate_up_A.e.weight'].to(W.device)
    B = tensors[f'model.layers.{i}.mlp.experts.lora_gate_up_B.e.weight'].to(W.device)
    for k in range(model.config.num_local_experts):
        a = A[k].reshape(-1, r)
        b = B[k].reshape(r, -1)

        m = torch.matmul(a, b) * merge_scale
        W[k] += m.to(W.dtype)

    W = state_dict[f'model.layers.{i}.mlp.experts.down_proj']
    A = tensors[f'model.layers.{i}.mlp.experts.lora_down_A.e.weight'].to(W.device)
    B = tensors[f'model.layers.{i}.mlp.experts.lora_down_B.e.weight'].to(W.device)
    for k in range(model.config.num_local_experts):
        a = A[k].reshape(-1, r)
        b = B[k].reshape(r, -1)

        m = torch.matmul(a, b) * merge_scale
        W[k] += m.to(W.dtype)

In [10]:
state_dict[f'model.layers.0.mlp.experts.gate_up_proj'][0]

tensor([[-0.0002,  0.0140,  0.0193,  ...,  0.0625,  0.0125, -0.0081],
        [-0.0011, -0.0045,  0.0108,  ..., -0.0625,  0.0050, -0.0029],
        [ 0.0005,  0.0016,  0.0258,  ...,  0.0080,  0.0327,  0.0322],
        ...,
        [ 0.0309, -0.0635, -0.0009,  ...,  0.0311,  0.0623,  0.0459],
        [-0.0005, -0.0957, -0.0016,  ...,  0.0623,  0.0308,  0.0454],
        [ 0.0625, -0.0620, -0.0138,  ..., -0.0466, -0.0300, -0.0938]],
       device='cuda:0', dtype=torch.bfloat16)

In [11]:
state_dict['model.layers.9.self_attn.v_proj.weight']

tensor([[ 0.0064,  0.0077, -0.0791,  ...,  0.0330,  0.0342,  0.0432],
        [-0.0146, -0.0332,  0.0131,  ..., -0.0352, -0.0640, -0.0188],
        [-0.0142,  0.0364,  0.0317,  ..., -0.0077,  0.0544, -0.0325],
        ...,
        [-0.0120, -0.0071, -0.0017,  ...,  0.0173, -0.0374,  0.0400],
        [-0.0219,  0.0141, -0.0359,  ...,  0.0256,  0.0449, -0.0396],
        [ 0.0254, -0.0132, -0.0476,  ...,  0.0330, -0.0322, -0.0256]],
       device='cuda:0', dtype=torch.bfloat16)

In [12]:
keys = tensors.keys()
keys_lora = sorted(list(set([k.split('.lora')[0] for k in keys if '.self_attn.' in k])))
for k in keys_lora:
    k_ori = k + '.weight'
    post_A = '.lora_A.e.weight'
    post_B = '.lora_B.e.weight'
    A = k + post_A
    B = k + post_B
    W = state_dict[k_ori]
    A = tensors[A].type(W.dtype).to(W.device)
    B = tensors[B].type(W.dtype).to(W.device)
    m = torch.matmul(A.t(), B.t()) * 2.0
    W += m.T.to(W.dtype)

In [13]:
state_dict['model.layers.9.self_attn.v_proj.weight']

tensor([[ 0.0065,  0.0059, -0.0796,  ...,  0.0349,  0.0349,  0.0437],
        [-0.0145, -0.0337,  0.0131,  ..., -0.0349, -0.0640, -0.0188],
        [-0.0135,  0.0364,  0.0312,  ..., -0.0074,  0.0547, -0.0325],
        ...,
        [-0.0127, -0.0066, -0.0015,  ...,  0.0172, -0.0374,  0.0398],
        [-0.0212,  0.0151, -0.0356,  ...,  0.0251,  0.0447, -0.0403],
        [ 0.0253, -0.0121, -0.0471,  ...,  0.0325, -0.0322, -0.0262]],
       device='cuda:0', dtype=torch.bfloat16)

In [14]:
q = """
Pasangan algoritma yang digunakan untuk melakukan penyulitan dan nyahsulit dikenali sebagai
A. kunci (keys)
B. Sifer (cipher)
C. Teks sifer (ciphertext)
""".strip()

system = 'First, you try to think step-by-step in {{lang}}, after that, put your final answer within $\\boxed{}$.'
messages = [
    {"role": "system", "content": system.replace('{{lang}}', 'malay')},
    {"role": "user", "content": q},
]

input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)

In [15]:
q = """
Budak itu sangat nakal, pantang orang leka sedikit, duit syiling pun dikebasnya.

terjemah ke kedah
"""

system = 'First, you try to think step-by-step in {{lang}}, after that, put your final answer within $\\boxed{}$.'
messages = [
    {"role": "system", "content": system.replace('{{lang}}', 'malay')},
    {"role": "user", "content": q},
]

row = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=False
) + '<|channel|>analysis<|message|>'
input_ids = tokenizer(row, add_special_tokens = False, return_tensors = 'pt').to(model.device)['input_ids']
input_ids

tensor([[200006,  17360, 200008,   3575,    553,  17554, 162016,     11,    261,
           4410,   6439,   2359,  22203,    656,   7788,  17527,    558,  87447,
         100594,     25,    220,   1323,     19,     12,   3218,    198,   6576,
           3521,     25,    220,   1323,     20,     12,    899,     12,   1311,
            279,  30377,    289,     25,  14093,    279,      2,  13888,  18403,
             25,   8450,     11,  49159,     11,   1721,     13,  21030,   2804,
            413,   7360,    395,   1753,   3176,     13, 200007, 200006,  77944,
         200008,      2,  68406,    279,   7127,     11,    481,   2075,    316,
           2411,   5983,  23541,  41570,    306,   3849,    356,     11,   1934,
            484,     11,   3006,    634,   1721,   6052,   3518, 126456, 172278,
          12083,      3,    364, 200007, 200006,   1428, 200008,    198,  75908,
            422,  15598,  32777,  11211,    280,     11,  22732,    516,  26322,
         136404,  90461,    

In [16]:
gen_kwargs = {"max_new_tokens": 512, "do_sample": True, "temperature": 0.6, "top_p": None, "top_k": None}

output_ids = model.generate(input_ids, **gen_kwargs)
response = tokenizer.batch_decode(output_ids)[0]
response

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


'<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.\nKnowledge cutoff: 2024-06\nCurrent date: 2025-12-13\n\nReasoning: medium\n\n# Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>developer<|message|># Instructions\n\nFirst, you try to think step-by-step in malay, after that, put your final answer within $\\boxed{}$.\n\n<|end|><|start|>user<|message|>\nBudak itu sangat nakal, pantang orang leka sedikit, duit syiling pun dikebasnya.\n\nterjemah ke kedah\n<|end|><|start|>assistant<|channel|>analysis<|message|>Okay, so I have a local Malay sentence that needs to be translated into standard Kedah. The original sentence is:\n\n"Budak itu sangat nakal, pantang orang leka sedikit, duit syiling pun dikebasnya."\n\nFirst, I need to understand the meaning of the entire sentence before attempting a translation. The speaker is describing someone (a child) who is very mischievous, hard to get rid of, and even i