# Graph Test

In [1]:
# Import necessary libraries
import json
import torch
from transformers import Qwen3Config
from nandmachine.frontend.network.qwen3 import Qwen3DecoderLayer
from nandmachine.frontend.core.graph.base import NxTracer
from nandmachine.frontend.network.torch_kernels import *

# Load config from JSON
with open('model_cards/qwen3-8B.json', 'r') as f:
    config_dict = json.load(f)

config = Qwen3Config(**config_dict)
print("Config loaded successfully:")
print(f"Hidden size: {config.hidden_size}")
print(f"Num attention heads: {config.num_attention_heads}")
print(f"Num layers: {config.num_hidden_layers}")

Config loaded successfully:
Hidden size: 4096
Num attention heads: 32
Num layers: 36


In [2]:
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.passes.fake_tensor_prop import FakeTensorProp


fake_mode = FakeTensorMode(allow_non_fake_inputs=True)

In [3]:
# Create Qwen3DecoderLayer instance
with fake_mode:
    layer = Qwen3DecoderLayer(config)
print("\nQwen3DecoderLayer created successfully")


Qwen3DecoderLayer created successfully


In [4]:
# Trace the computation graph using NxTracer
tracer = NxTracer()

# Set the layer to eval mode to avoid training-specific behavior
layer.eval()

print("\nTracing the computation graph...")
graph = tracer.trace(layer)
gm = torch.fx.GraphModule(layer,graph)
print("Graph traced successfully!")
print(f"Total nodes in graph: {len(graph.nodes)}")



Tracing the computation graph...
Graph traced successfully!
Total nodes in graph: 57


In [5]:
print(gm.code)




def forward(self, positions : torch.Tensor, hidden_states : torch.Tensor):
    self_attn_qkv_proj_weight = self.self_attn.qkv_proj.weight
    linear = torch._C._nn.linear(hidden_states, self_attn_qkv_proj_weight, None);  self_attn_qkv_proj_weight = None
    split = linear.split([4096, 1024, 1024], dim = -1);  linear = None
    getitem = split[0]
    getitem_1 = split[1]
    getitem_2 = split[2];  split = None
    view = getitem.view(-1, 32, 128);  getitem = None
    view_1 = getitem_1.view(-1, 8, 128);  getitem_1 = None
    view_2 = getitem_2.view(-1, 8, 128);  getitem_2 = view_2 = None
    self_attn_rotary_emb_cos_sin_cache = self.self_attn.rotary_emb.cos_sin_cache
    getitem_3 = self_attn_rotary_emb_cos_sin_cache.__getitem__(positions);  self_attn_rotary_emb_cos_sin_cache = positions = None
    chunk = getitem_3.chunk(2, dim = -1);  getitem_3 = None
    getitem_4 = chunk[0]
    getitem_5 = chunk[1];  chunk = None
    float_1 = view.float()
    chunk_1 = torch.chunk(float_1, 2, di

In [6]:


with fake_mode:
    input_hidden_states = torch.empty([16,1,4096])
    input_position = torch.empty([16],dtype=torch.int)


fake_prop = FakeTensorProp(gm,mode=fake_mode)

fake_prop.propagate(input_position,input_hidden_states)


In [7]:
layer(input_position, input_hidden_states)

FakeTensor(..., size=(16, 16, 4096), grad_fn=<AddBackward0>)

In [8]:
# Use FakeTensorMode to trace the graph
from torch._subclasses.fake_tensor import FakeTensorMode

# Allow mixing of fake tensors and real tensors (like cos_sin_cache)
fake_mode = FakeTensorMode(allow_non_fake_inputs=True)

# Create fake inputs with shape (16, 1, 4096)
with fake_mode:
    positions = torch.zeros(16, dtype=torch.long)
    hidden_states = torch.zeros(16, 1, 4096)

print(f"\nInput shapes:")
print(f"  positions: {positions.shape}")
print(f"  hidden_states: {hidden_states.shape}")



Input shapes:
  positions: torch.Size([16])
  hidden_states: torch.Size([16, 1, 4096])
