In [1]:
from transformers import Qwen3ForCausalLM
from transformer_engine import pytorch as te
from transformer_engine.common import recipe
import torch.nn as nn
import torch

def convert_model(model):
    """
    Recursively converts the linear and layernorm layers of a model to their `transformers_engine` counterpart.
    Modified from https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/transformer_engine.py#L26
    Should apply after load the model with intended precision.
    """

    for name, module in model.named_children():

        if "lm_head" in name:
            continue
        if isinstance(module, nn.Linear):
            has_bias = module.bias is not None
            params_to_gather = [module.weight]
            if any(p % 16 != 0 for p in module.weight.shape):
                return
            te_module = te.Linear(
                module.in_features, module.out_features, bias=has_bias, params_dtype=module.weight.dtype
            )
            te_module.weight.copy_(module.weight)
            if has_bias:
                te_module.bias.copy_(module.bias)

            setattr(model, name, te_module)
        else:
            convert_model(module)

Supported flash-attn versions are >= 2.7.3, <= 2.8.1. Found flash-attn 2.8.3.


In [2]:
class Model(Qwen3ForCausalLM):
    def __init__(self, config):
        super().__init__(config)

    def estimate_flops(self, t = 4096):
        """
        Return the estimated FLOPs per token for the model. Ref: https://arxiv.org/abs/2204.02311 
        Borrow from https://github.com/karpathy/nanochat/blob/master/nanochat/gpt.py#L220
        """
        nparams = sum(p.numel() for p in self.parameters())
        nparams_embedding = self.model.embed_tokens.weight.numel()
        l, h, q = self.config.num_hidden_layers, self.config.num_attention_heads, self.config.head_dim
        num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
        return num_flops_per_token

In [3]:
model = Model.from_pretrained('Qwen/Qwen3-0.6B-Base')
model.estimate_flops()

5461377024

In [4]:
with torch.no_grad():
    convert_model(model)

_ = model.to('cuda')
model = torch.compile(model)

In [5]:
optim = torch.optim.AdamW(model.parameters(), lr=2e-5, fused=True)

In [6]:
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3)

In [7]:
# must divisible by 8
input_ids = torch.arange(8)[None].cuda()
labels = input_ids.clone().cuda()
b = {
    'input_ids': input_ids,
    'labels': labels,
}

In [8]:
input_ids.shape, labels.shape

(torch.Size([1, 8]), torch.Size([1, 8]))

In [9]:
with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=True):
    if fp8_recipe is not None:
        with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
            out = model(**b)

W1018 08:26:15.764000 6077 site-packages/torch/_dynamo/convert_frame.py:1016] [13/8] torch._dynamo hit config.recompile_limit (8)
W1018 08:26:15.764000 6077 site-packages/torch/_dynamo/convert_frame.py:1016] [13/8]    function: 'torch_dynamo_resume_in_forward_at_202' (/venv/main/lib/python3.12/site-packages/transformers/models/qwen3/modeling_qwen3.py:202)
W1018 08:26:15.764000 6077 site-packages/torch/_dynamo/convert_frame.py:1016] [13/8]    last reason: 13/7: past_key_values.layers[7].is_initialized == False      
W1018 08:26:15.764000 6077 site-packages/torch/_dynamo/convert_frame.py:1016] [13/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W1018 08:26:15.764000 6077 site-packages/torch/_dynamo/convert_frame.py:1016] [13/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html.
`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.


In [10]:
out.loss

tensor(4.8944, device='cuda:0', grad_fn=<CompiledFunctionBackward>)

In [12]:
# loss = out["loss"]
# loss.backward()

In [None]:
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optim.step()
scheduler.step()
optim.zero_grad(set_to_none=True)