In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F


* Only the layers of the problematic subtree are shown:

```
TimeContextBlock1d
└── tcm (Sequential)
    ├── ConvNeXtLikeBlock (kernel=7)
    │   └── dwconv → norm → GELU → pwconv1
    ├── ConvNeXtLikeBlock (kernel=19)
    │   └── dwconv → norm → GELU → pwconv1
    ├── ConvNeXtLikeBlock (kernel=31)
    │   └── dwconv → norm → GELU → pwconv1
    ├── ConvNeXtLikeBlock (kernel=59)
    │   └── dwconv → norm → GELU → pwconv1
    └── TransformerEncoderLayer
        ├── MultiHeadAttention
        │   └── k_proj, q_proj, v_proj, out_proj
        ├── LayerNorm
        └── FeedForward
            └── intermediate_dense → GELU → output_dense
└── exp_dim_conv: Conv1d(20 → 600)
```

Processing of these layers hangs after the following log lines:


```
I     input_align_4D_add: remove node = [], add node = ['/backbone/stage0/stage0.6/tcm/tcm.3/Add_reshape']
I     fuse_two_reshape: remove node = ['/backbone/stage0/stage0.6/tcm/tcm.4/attention/Mul_0_unsqueeze0']
I     input_align_4D_mul: remove node = [], add node = ['/backbone/stage0/stage0.6/tcm/tcm.4/attention/Mul_reshape']
I     fuse_two_reshape: remove node = ['/backbone/stage0/stage0.6/tcm/tcm.4/attention/MatMul_0_unsqueeze1', '/backbone/stage0/stage0.6/tcm/tcm.4/attention/Softmax_0_unsqueeze1']
I     input_align_4D_add: remove node = [], add node = ['/backbone/stage0/stage0.6/tcm/tcm.4/Add_reshape']
I     input_align_4D_add: remove node = [], add node = ['/backbone/stage0/stage0.6/tcm/tcm.4/Add_1_reshape']
I     input_align_4D_add: remove node = [], add node = ['/backbone/stage0/stage0.6/Add_reshape']
I     input_align_4D_mul: remove node = [], add node = ['/pool/Mul_reshape']
I     fuse_two_reshape: remove node = ['/pool/Mul_1_0_unsqueeze0']
I     input_align_4D_mul: remove node = [], add node = ['/pool/Mul_1_reshape']
I     fuse_two_reshape: remove node = ['/pool/Mul_1_0_unsqueeze1']
I     input_align_4D_mul: remove node = [], add node = ['/pool/Div_2mul_reshape']
I     fuse_two_reshape: remove node = ['/pool/Div_2mul_0_unsqueeze1']
I     input_align_4D_add: remove node = [], add node = ['/pool/Add_reshape']
I     remove_parallel_reshape: remove node = ['/pool/Mul_2_reshape']
I     input_align_4D_mul: remove node = [], add node = ['/pool/Mul_2_reshape']
I     fuse_two_reshape: remove node = ['/pool/ReduceSum_reshape']



```

In [1]:

class NetworkTest1(nn.Module):
    """Minimal ConvNeXt-like block: depthwise conv -> BatchNorm -> GELU -> pointwise conv.

    Args:
        dim: Number of input and output channels; also used as the group
            count of the first convolution, which makes it depthwise.
        kernel_size: Kernel width of the depthwise convolution.
    """

    def __init__(self, dim=20, kernel_size=7):
        super().__init__()
        # padding="same" lets the convolution itself keep the last axis length.
        self.dwconv = nn.Conv1d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=kernel_size,
            padding="same",
            groups=dim,
        )
        self.norm = nn.BatchNorm1d(num_features=dim)
        self.act = nn.GELU()
        # 1x1 convolution mixing channels at every position.
        self.pwconv1 = nn.Conv1d(in_channels=dim, out_channels=dim, kernel_size=1)

    def forward(self, x):
        """Apply dwconv, norm, act and pwconv1 in that order and return the result."""
        for stage in (self.dwconv, self.norm, self.act, self.pwconv1):
            x = stage(x)
        return x


NameError: name 'nn' is not defined

In [3]:
class NetworkTest1_fix(nn.Module):
    """ConvNeXt-like block with explicit zero padding and ReLU.

    The depthwise convolution is built with ``padding=0`` and the input is
    padded by hand in ``forward`` so that the output keeps the input length
    for any ``kernel_size`` (odd or even).

    Args:
        dim: Number of input and output channels; also the group count of
            the depthwise convolution.
        kernel_size: Kernel width of the depthwise convolution.
    """

    def __init__(self, dim=20, kernel_size=7):
        super().__init__()
        # Kept as an attribute because forward() derives the padding from it.
        self.kernel_size = kernel_size
        self.dwconv = nn.Conv1d(dim, dim, kernel_size, padding=0, groups=dim)
        self.norm = nn.BatchNorm1d(dim)
        self.act = nn.ReLU()  # GELU not fully supported in ONNX/RKNN
        self.pwconv1 = nn.Conv1d(dim, dim, kernel_size=1)

    def forward(self, x):  # x: [B, C, T]
        """Pad the last axis to emulate 'same' padding, then run the block.

        Returns a tensor whose last-axis length equals that of ``x``.
        """
        # Manually compute SAME padding: with stride 1 and dilation 1 the
        # total is kernel_size - 1. For even kernels that total is odd, so
        # the extra sample goes on the right; a symmetric (k - 1) // 2 split
        # would shorten the output by one.
        total_pad = self.kernel_size - 1
        left_pad = total_pad // 2
        right_pad = total_pad - left_pad
        x = F.pad(x, (left_pad, right_pad))
        x = self.dwconv(x)
        x = self.norm(x)
        x = self.act(x)
        x = self.pwconv1(x)
        return x



In [4]:
class NetworkTest2(nn.Module):
    """Four same-width linear layers named after attention projections.

    No attention is computed here: the output is ``out_proj(q_proj(x))``.

    Args:
        dim: Input and output feature size of every linear layer.
    """

    def __init__(self, dim=20):
        super().__init__()
        self.q_proj = nn.Linear(in_features=dim, out_features=dim)
        self.k_proj = nn.Linear(in_features=dim, out_features=dim)
        self.v_proj = nn.Linear(in_features=dim, out_features=dim)
        self.out_proj = nn.Linear(in_features=dim, out_features=dim)

    def forward(self, x):
        """Return ``out_proj(q_proj(x))``; assumes input x: (B, T, dim)."""
        query = self.q_proj(x)
        # NOTE(review): the key/value projections are evaluated but their
        # results never reach the output, so a trace-based export may drop
        # k_proj and v_proj from the graph -- confirm this is intended.
        key, value = self.k_proj(x), self.v_proj(x)
        # simulate attention output
        return self.out_proj(query)


In [5]:
class NetworkTest3(nn.Module):
    """Self-attention followed by LayerNorm.

    Args:
        dim: Embedding size for the attention layer and the LayerNorm.
        num_heads: Number of attention heads.
    """

    def __init__(self, dim=20, num_heads=2):
        super().__init__()
        # batch_first=True: the attention layer takes batch-leading input.
        self.attn = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=num_heads,
            batch_first=True,
        )
        self.norm = nn.LayerNorm(normalized_shape=dim)

    def forward(self, x):
        """Attend x to itself and normalise; assumes input x: (B, T, dim)."""
        # The attention layer returns (output, weights); only the output is used.
        attended = self.attn(x, x, x)[0]
        return self.norm(attended)


In [6]:
class NetworkTest4(nn.Module):
    """Self-attention -> LayerNorm -> feed-forward -> LayerNorm.

    Args:
        dim: Embedding size used by the attention layer, both LayerNorms
            and the input/output of the feed-forward stack.
        num_heads: Number of attention heads.
        ff_dim: Hidden width of the feed-forward stack.
    """

    def __init__(self, dim=20, num_heads=2, ff_dim=64):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=num_heads,
            batch_first=True,
        )
        self.norm1 = nn.LayerNorm(normalized_shape=dim)
        # Expand to ff_dim, apply GELU, project back to dim.
        expand = nn.Linear(in_features=dim, out_features=ff_dim)
        contract = nn.Linear(in_features=ff_dim, out_features=dim)
        self.ff = nn.Sequential(expand, nn.GELU(), contract)
        self.norm2 = nn.LayerNorm(normalized_shape=dim)

    def forward(self, x):
        """Run attention, norm1, the feed-forward stack and norm2 in sequence."""
        # The attention layer returns (output, weights); only the output is used.
        attended = self.attn(x, x, x)[0]
        hidden = self.norm1(attended)
        return self.norm2(self.ff(hidden))


In [7]:
def export_onnx(model, name, shape, opset_version=13):
    """Export ``model`` to ``<name>.onnx`` using a random dummy input.

    Args:
        model: The ``nn.Module`` to export. It is switched to eval mode
            in place before export.
        name: Output file stem; the file is written as ``f"{name}.onnx"``
            relative to the current working directory.
        shape: Iterable of ints giving the shape of the dummy input tensor.
        opset_version: ONNX opset passed to ``torch.onnx.export``.
            Defaults to 13, the value previously hard-coded here.
    """
    # Eval mode so layers such as BatchNorm/Dropout use inference behaviour.
    model.eval()
    dummy_input = torch.randn(*shape)

    torch.onnx.export(
        model,
        dummy_input,
        f"{name}.onnx",
        input_names=["input"],
        output_names=["output"],
        opset_version=opset_version,
    )
    print(f"Exported {name}.onnx")

In [8]:
# Dummy-input shapes shared by the exports below.
CONV_INPUT_SHAPE = (1, 20, 100)  # Test1 variants: [B, C, T]
SEQ_INPUT_SHAPE = (1, 100, 20)   # Test2-4: [B, T, C]

# Convolutional block, original and hand-padded variant.
model1 = NetworkTest1()
export_onnx(model1, "network_test1", shape=CONV_INPUT_SHAPE)

model1_fix = NetworkTest1_fix()
export_onnx(model1_fix, "network_test1_fix", shape=CONV_INPUT_SHAPE)

# Linear projections only.
model2 = NetworkTest2()
export_onnx(model2, "network_test2", shape=SEQ_INPUT_SHAPE)

# Multi-head attention + LayerNorm.
model3 = NetworkTest3()
export_onnx(model3, "network_test3", shape=SEQ_INPUT_SHAPE)

# Attention + feed-forward with both LayerNorms.
model4 = NetworkTest4()
export_onnx(model4, "network_test4", shape=SEQ_INPUT_SHAPE)

Exported network_test1.onnx
Exported network_test1_fix.onnx
Exported network_test2.onnx
Exported network_test3.onnx
Exported network_test4.onnx


  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(
