In [1]:
# Import necessary libraries
import json
import torch
from transformers import Qwen3Config
from nandmachine.frontend.network.qwen3 import Qwen3DecoderLayer
from nandmachine.frontend.core.graph.base import NxTracer
from nandmachine.frontend.network.torch_kernels import *

# Load the model card and build a Qwen3Config from it. The JSON's top-level keys
# are passed straight through as Qwen3Config keyword arguments.
# The path is relative to the notebook's working directory.
MODEL_CARD_PATH = 'model_cards/qwen3-8B.json'

# Explicit UTF-8: without `encoding`, open() uses the platform's locale encoding,
# which can mis-decode JSON on some systems.
with open(MODEL_CARD_PATH, 'r', encoding='utf-8') as f:
    config_dict = json.load(f)

config = Qwen3Config(**config_dict)
print("Config loaded successfully:")
print(f"Hidden size: {config.hidden_size}")
print(f"Num attention heads: {config.num_attention_heads}")
print(f"Num layers: {config.num_hidden_layers}")

Config loaded successfully:
Hidden size: 4096
Num attention heads: 32
Num layers: 36


In [2]:
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.passes.fake_tensor_prop import FakeTensorProp

# One shared FakeTensorMode, reused for building the layer, creating the fake
# inputs, and FakeTensorProp in the cells below.
# allow_non_fake_inputs=True lets ops in this mode accept real (non-fake) tensors
# instead of raising. This is presumably needed because some tensors reaching the
# graph are not created under fake_mode (TODO confirm which ones).
fake_mode = FakeTensorMode(allow_non_fake_inputs=True)

In [3]:
# Build a single decoder layer inside fake_mode, so the parameters created in its
# __init__ are presumably fake tensors (shape/dtype only). Then switch it to
# inference mode.
with fake_mode:
    layer = Qwen3DecoderLayer(config)
    layer.eval()
print("\nQwen3DecoderLayer created successfully")

# Symbolically trace the layer with the project's NxTracer, then wrap the
# resulting fx graph in a GraphModule so it can be executed by graph passes.
tracer = NxTracer()
print("\nTracing the computation graph...")
graph = tracer.trace(layer)
gm = torch.fx.GraphModule(layer, graph)

# Two status lines, emitted by a single print call.
print("Graph traced successfully!", f"Total nodes in graph: {len(graph.nodes)}", sep="\n")


Qwen3DecoderLayer created successfully

Tracing the computation graph...
Graph traced successfully!
Total nodes in graph: 57


In [4]:
# Fake inputs for one forward pass of the decoder layer.
# Tensors created under fake_mode carry only shape/dtype metadata, not real data.
# NUM_TOKENS sizes both inputs.
# NOTE(review): this assumes dim 0 of hidden_states indexes the same tokens as
# `positions`; confirm against Qwen3DecoderLayer.forward.
NUM_TOKENS = 16
with fake_mode:
    # The feature dim comes from the config rather than a hard-coded 4096.
    input_hidden_states = torch.empty([NUM_TOKENS, 1, config.hidden_size])
    # Kept as torch.int (int32) to match the original propagation.
    input_position = torch.empty([NUM_TOKENS], dtype=torch.int)


fake_prop = FakeTensorProp(gm, mode=fake_mode)

# Positional args follow the graph's placeholder order: positions first, then
# hidden_states. Propagation stores a fake 'val' in each node's meta.
fake_prop.propagate(input_position, input_hidden_states)

In [5]:
from nandmachine.frontend.core.passes.recorder import RecorderPass

# Run the project's RecorderPass over the fake-propagated GraphModule.
# NOTE(review): this rebinds `graph`, which previously held the traced fx graph,
# to whatever transform() returns. The next cell iterates `graph.nodes`, and the
# meta it dumps includes 'output_shapes'/'input_shapes' keys, presumably added by
# this pass; confirm in RecorderPass. A distinct name (e.g. `recorded_graph`)
# would avoid the hidden-state coupling.
recorder_pass = RecorderPass()

graph = recorder_pass.transform(gm)

In [6]:
# Print each graph node's name together with its full meta dict
# (e.g. the fake 'val' and the recorded shapes).
for node in graph.nodes:
    print('node name:{} meta: {}'.format(node.name, node.meta))

node name:positions meta: {'val': FakeTensor(..., size=(16,), dtype=torch.int32), 'output_shapes': (16,)}
node name:hidden_states meta: {'val': FakeTensor(..., size=(16, 1, 4096)), 'output_shapes': (16, 1, 4096)}
node name:self_attn_qkv_proj_weight meta: {'nn_module_stack': OrderedDict({'self_attn': ('self_attn', <class 'nandmachine.frontend.network.qwen3.Qwen3Attention'>), 'self_attn.qkv_proj': ('self_attn.qkv_proj', <class 'nandmachine.frontend.network.torch_kernels.QKVParallelLinear'>)}), 'val': FakeTensor(..., size=(6144, 4096)), 'output_shapes': (6144, 4096)}
node name:linear meta: {'nn_module_stack': OrderedDict({'self_attn': ('self_attn', <class 'nandmachine.frontend.network.qwen3.Qwen3Attention'>), 'self_attn.qkv_proj': ('self_attn.qkv_proj', <class 'nandmachine.frontend.network.torch_kernels.QKVParallelLinear'>)}), 'val': FakeTensor(..., size=(16, 1, 6144)), 'output_shapes': (16, 1, 6144), 'input_shapes': [(16, 1, 4096), (6144, 4096)]}
node name:split meta: {'nn_module_stack':