In [1]:
from safetensors import safe_open
from transformers.trainer_utils import get_last_checkpoint
from glob import glob
from transformers import GptOssForCausalLM, AutoModelForCausalLM, AutoTokenizer, Mxfp4Config, TextStreamer
from peft import PeftModel
from tqdm import tqdm
import torch
torch.set_grad_enabled(False)

  import pynvml  # type: ignore[import]
Skipping import of cpp extensions due to incompatible torch version 2.9.0+cu128 for torchao version 0.15.0             Please see https://github.com/pytorch/ao/issues/2919 for more info


torch.autograd.grad_mode.set_grad_enabled(mode=False)

In [2]:
tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-120b")

In [3]:
from glob import glob

files = sorted(glob('gfs-01be5b33-gpt-oss-120b-BF16/*'))
files

['gfs-01be5b33-gpt-oss-120b-BF16/315-model_state_dict.pt',
 'gfs-01be5b33-gpt-oss-120b-BF16/631-model_state_dict.pt',
 'gfs-01be5b33-gpt-oss-120b-BF16/947-model_state_dict.pt']

In [4]:
model_kwargs = dict(
    attn_implementation="kernels-community/vllm-flash-attn3",
    torch_dtype=torch.bfloat16, 
    use_cache=True,
    device_map='auto',
    dtype='auto',
)
model = AutoModelForCausalLM.from_pretrained("gfs/01be5b33/gpt-oss-120b-BF16", **model_kwargs)

`torch_dtype` is deprecated! Use `dtype` instead!


Fetching 7 files:   0%|          | 0/7 [00:00<?, ?it/s]

Fetching 7 files:   0%|          | 0/7 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/73 [00:00<?, ?it/s]

In [5]:
state_dict = model.state_dict()

In [6]:
f = files[0]
mapping = torch.load(f, map_location='cpu')
keys = mapping.keys()
state_dict = model.state_dict()

In [7]:
for i in tqdm(range(model.config.num_hidden_layers)):
    A = f'model.layers.{i}.mlp.experts.gate_lora.A'
    if A in mapping:
        B = f'model.layers.{i}.mlp.experts.gate_lora.B'
        W = f'model.layers.{i}.mlp.experts.gate_up_proj'
        W = state_dict[W]
        a = mapping[A].to(W.device)
        b = mapping[B].to(W.device)

        m = torch.matmul(a, b) * 2.0
        W += m.to(W.dtype)
        
    A = f'model.layers.{i}.mlp.experts.down_lora.A'
    if A in mapping:
        B = f'model.layers.{i}.mlp.experts.down_lora.B'
        W = f'model.layers.{i}.mlp.experts.down_proj'
        W = state_dict[W]
        a = mapping[A].to(W.device)
        b = mapping[B].to(W.device)

        m = torch.matmul(a, b) * 2.0
        W += m.to(W.dtype)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 36/36 [00:01<00:00, 23.20it/s]


In [8]:
keys_lora = [k.split('.lora')[0] for k in keys if '.lora' in k]
keys_lora = sorted(list(set(keys_lora)))
for k in tqdm(keys_lora):
    k_ori = k.replace('_orig_mod.', '') + '.weight'
    post_A = '.lora_A'
    post_B = '.lora_B'
    A = k + post_A
    B = k + post_B
    W = state_dict[k_ori]
    A = mapping[A].to(W.device)
    B = mapping[B].to(W.device)
    m = torch.matmul(A.t(), B.t()) * 2.0
    W += m.T.to(W.dtype)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 144/144 [00:00<00:00, 2249.84it/s]


In [9]:
q = """
Budak itu sangat nakal, pantang orang leka sedikit, duit syiling pun dikebasnya.

terjemah ke kedah
"""

system = 'First, you try to think step-by-step in {{lang}}, after that, put your final answer within $\\boxed{}$.'
messages = [
    {"role": "system", "content": system.replace('{{lang}}', 'malay')},
    {"role": "user", "content": q},
]

row = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=False
)
input_ids = tokenizer(row, add_special_tokens = False, return_tensors = 'pt').to(model.device)['input_ids']
input_ids

tensor([[200006,  17360, 200008,   3575,    553,  17554, 162016,     11,    261,
           4410,   6439,   2359,  22203,    656,   7788,  17527,    558,  87447,
         100594,     25,    220,   1323,     19,     12,   3218,    198,   6576,
           3521,     25,    220,   1323,     21,     12,   2290,     12,   3114,
            279,  30377,    289,     25,  14093,    279,      2,  13888,  18403,
             25,   8450,     11,  49159,     11,   1721,     13,  21030,   2804,
            413,   7360,    395,   1753,   3176,     13, 200007, 200006,  77944,
         200008,      2,  68406,    279,   7127,     11,    481,   2075,    316,
           2411,   5983,  23541,  41570,    306,   3849,    356,     11,   1934,
            484,     11,   3006,    634,   1721,   6052,   3518, 126456, 172278,
          12083,      3,    364, 200007, 200006,   1428, 200008,    198,  75908,
            422,  15598,  32777,  11211,    280,     11,  22732,    516,  26322,
         136404,  90461,    

In [10]:
streamer = TextStreamer(tokenizer)

In [None]:
gen_kwargs = {
    "max_new_tokens": 1024, 
    "do_sample": True, 
    "temperature": 0.9, 
    "top_p": None, 
    "top_k": None,
    "streamer": streamer,
}

output_ids = model.generate(input_ids, **gen_kwargs)
response = tokenizer.batch_decode(output_ids)[0]
response

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.
Knowledge cutoff: 2024-06
Current date: 2026-01-09

Reasoning: medium

# Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>developer<|message|># Instructions

First, you try to think step-by-step in malay, after that, put your final answer within $\boxed{}$.

<|end|><|start|>user<|message|>
Budak itu sangat nakal, pantang orang leka sedikit, duit syiling pun dikebasnya.

terjemah ke kedah
<|end|><|start|>assistant<|channel|>analysis<|message|>Baiklah, saya akan menerangkan dengan sangat terperinci dan panjang bagaimana ayat Bahasa Melayu standard di atas, **"Budak itu sangat nakal, pantang orang leka sedikit, duit syiling pun dikebasnya."**, ditulis semula dalam dialek Kedah. Saya akan menerangkan setiap perkataan, perubahan bunyi, struktur ayat, dan nuansa budaya yang berkaitan supaya anda dapat memahami proses transformasi ayat ini dengan sangat m