In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F


* Only the problematic layers of the model tree are shown:

```
TimeContextBlock1d
├── tcm (Sequential)
│   ├── ConvNeXtLikeBlock (kernel=7)
│   │   └── dwconv → norm → GELU → pwconv1
│   ├── ConvNeXtLikeBlock (kernel=19)
│   │   └── dwconv → norm → GELU → pwconv1
│   ├── ConvNeXtLikeBlock (kernel=31)
│   │   └── dwconv → norm → GELU → pwconv1
│   ├── ConvNeXtLikeBlock (kernel=59)
│   │   └── dwconv → norm → GELU → pwconv1
│   └── TransformerEncoderLayer
│       ├── MultiHeadAttention
│       │   └── k_proj, q_proj, v_proj, out_proj
│       ├── LayerNorm
│       └── FeedForward
│           └── intermediate_dense → GELU → output_dense
└── exp_dim_conv: Conv1d(20 → 600)
```

The conversion hangs after printing the following log lines:


```
I     input_align_4D_add: remove node = [], add node = ['/backbone/stage0/stage0.6/tcm/tcm.3/Add_reshape']
I     fuse_two_reshape: remove node = ['/backbone/stage0/stage0.6/tcm/tcm.4/attention/Mul_0_unsqueeze0']
I     input_align_4D_mul: remove node = [], add node = ['/backbone/stage0/stage0.6/tcm/tcm.4/attention/Mul_reshape']
I     fuse_two_reshape: remove node = ['/backbone/stage0/stage0.6/tcm/tcm.4/attention/MatMul_0_unsqueeze1', '/backbone/stage0/stage0.6/tcm/tcm.4/attention/Softmax_0_unsqueeze1']
I     input_align_4D_add: remove node = [], add node = ['/backbone/stage0/stage0.6/tcm/tcm.4/Add_reshape']
I     input_align_4D_add: remove node = [], add node = ['/backbone/stage0/stage0.6/tcm/tcm.4/Add_1_reshape']
I     input_align_4D_add: remove node = [], add node = ['/backbone/stage0/stage0.6/Add_reshape']
I     input_align_4D_mul: remove node = [], add node = ['/pool/Mul_reshape']
I     fuse_two_reshape: remove node = ['/pool/Mul_1_0_unsqueeze0']
I     input_align_4D_mul: remove node = [], add node = ['/pool/Mul_1_reshape']
I     fuse_two_reshape: remove node = ['/pool/Mul_1_0_unsqueeze1']
I     input_align_4D_mul: remove node = [], add node = ['/pool/Div_2mul_reshape']
I     fuse_two_reshape: remove node = ['/pool/Div_2mul_0_unsqueeze1']
I     input_align_4D_add: remove node = [], add node = ['/pool/Add_reshape']
I     remove_parallel_reshape: remove node = ['/pool/Mul_2_reshape']
I     input_align_4D_mul: remove node = [], add node = ['/pool/Mul_2_reshape']
I     fuse_two_reshape: remove node = ['/pool/ReduceSum_reshape']



```

* The likely culprits are the transformer-style self-attention operations, especially:
    * The MatMul and Softmax operators, which work on [B, H, T, D]-shaped tensors.
    * The `unsqueeze1` reshape nodes in the log, which point to an attempt to realign tensors for broadcasting inside attention.
    * RKNN often has problems with ONNX ops that rely on dynamic reshaping or on broadcasting across mismatched dimensions, both of which are typical in attention.

In [2]:
class NewGELUActivation(nn.Module):
    """GELU activation wrapped as a module.

    Calls ``F.gelu`` with ``approximate='none'``, i.e. no tanh approximation
    is applied despite the "New" in the class name.
    """

    def forward(self, x):
        """Return the element-wise GELU of ``x``."""
        activated = F.gelu(x, approximate='none')
        return activated

class FeedForward(nn.Module):
    """Two-layer position-wise MLP: Linear -> GELU -> Linear.

    Both linear layers map ``dim`` features to ``dim`` features. The two
    dropout modules are constructed with probability 0.0.
    """

    def __init__(self, dim):
        super().__init__()
        # Registration order is kept so parameter initialisation and the
        # state-dict layout are unchanged.
        self.intermediate_dense = nn.Linear(dim, dim)
        self.intermediate_act_fn = NewGELUActivation()
        self.output_dense = nn.Linear(dim, dim)
        self.intermediate_dropout = nn.Dropout(0.0)
        self.output_dropout = nn.Dropout(0.0)

    def forward(self, x):
        """Apply dense -> activation -> dropout -> dense -> dropout to ``x``."""
        hidden = self.intermediate_dropout(
            self.intermediate_act_fn(self.intermediate_dense(x))
        )
        return self.output_dropout(self.output_dense(hidden))

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values with separate linear
    layers, split into ``num_heads`` heads of ``dim // num_heads`` features,
    attended with a softmax over the key axis, merged back and passed through
    an output projection.

    Args:
        dim: Feature size of the input and output (last dimension).
        num_heads: Number of attention heads; must divide ``dim`` evenly.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads):
        super().__init__()
        # Fail early: with a non-divisible dim the floor division below would
        # silently drop features and forward() would only fail later with an
        # obscure view() shape error.
        if dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        # Creation order (q, k, v, out) is kept so that seeded parameter
        # initialisation is unchanged.
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)

    def _split_heads(self, projected, batch, steps):
        """Reshape a projected [B, T, C] tensor to [B, num_heads, T, head_dim]."""
        return projected.view(batch, steps, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Run self-attention on ``x`` of shape [B, T, C]; returns [B, T, C]."""
        B, T, C = x.shape
        q = self._split_heads(self.q_proj(x), B, T)
        k = self._split_heads(self.k_proj(x), B, T)
        v = self._split_heads(self.v_proj(x), B, T)

        # [B, H, T, T] attention scores, normalised over the key axis.
        attn_weights = (q @ k.transpose(-2, -1)) * self.scale
        attn_weights = attn_weights.softmax(dim=-1)

        # Merge heads back: [B, H, T, D] -> [B, T, H*D] == [B, T, C].
        out = (attn_weights @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(out)

class TransformerEncoderLayer(nn.Module):
    """Pre-norm encoder layer: attention and feed-forward, each with a residual.

    Each sub-layer receives a layer-normalised copy of its input and its
    output is added back onto the un-normalised input.
    """

    def __init__(self, dim, num_heads):
        super().__init__()
        self.attention = MultiHeadAttention(dim, num_heads)
        self.layer_norm = nn.LayerNorm(dim, eps=1e-6)
        self.feed_forward = FeedForward(dim)
        self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)

    def forward(self, x):
        """Apply the attention block, then the feed-forward block, to ``x``."""
        # Attention sub-layer with residual connection.
        attended = self.attention(self.layer_norm(x))
        after_attention = x + attended

        # Feed-forward sub-layer with residual connection.
        transformed = self.feed_forward(self.final_layer_norm(after_attention))
        return after_attention + transformed

class Test4(nn.Module):
    """Minimal repro model: 1x1 conv reduce -> transformer layer -> 1x1 conv expand.

    Args:
        in_channels: Channel count of the input and of the output.
        mid_channels: Channel count fed to the transformer layer.
        seq_len: Accepted but not used anywhere in this class.
    """

    def __init__(self, in_channels=600, mid_channels=20, seq_len=100):
        super().__init__()
        self.reduce = nn.Conv1d(in_channels, mid_channels, kernel_size=1)
        # NOTE(review): registered (so it is part of the state dict) but never
        # called in forward(); confirm whether it was meant to be applied.
        self.norm = nn.LayerNorm(mid_channels)
        self.transformer = TransformerEncoderLayer(mid_channels, num_heads=4)
        self.expand = nn.Conv1d(mid_channels, in_channels, kernel_size=1)

    def forward(self, x):
        """Map a channels-first [B, in_channels, T] tensor to the same shape."""
        reduced = self.reduce(x)                # [B, mid_channels, T]
        tokens = reduced.permute(0, 2, 1)       # [B, T, mid_channels]
        tokens = self.transformer(tokens)       # [B, T, mid_channels]
        restored = tokens.permute(0, 2, 1)      # [B, mid_channels, T]
        return self.expand(restored)            # [B, in_channels, T]


In [3]:
# Build the repro model in inference mode and trace it with a fixed-size input.
model = Test4().eval()
dummy = torch.randn(1, 600, 100)  # [B, C, T]

# Export options kept together so they are easy to tweak between attempts.
export_kwargs = dict(
    input_names=["input"],
    output_names=["output"],
    opset_version=13,
    verbose=True,
)
torch.onnx.export(model, dummy, "test4_fail.onnx", **export_kwargs)
print("Exported to test4_fail.onnx")

Exported graph: graph(%input : Float(1, 600, 100, strides=[60000, 100, 1], requires_grad=0, device=cpu),
      %reduce.weight : Float(20, 600, 1, strides=[600, 1, 1], requires_grad=1, device=cpu),
      %reduce.bias : Float(20, strides=[1], requires_grad=1, device=cpu),
      %transformer.attention.q_proj.bias : Float(20, strides=[1], requires_grad=1, device=cpu),
      %transformer.attention.k_proj.bias : Float(20, strides=[1], requires_grad=1, device=cpu),
      %transformer.attention.v_proj.bias : Float(20, strides=[1], requires_grad=1, device=cpu),
      %transformer.attention.out_proj.bias : Float(20, strides=[1], requires_grad=1, device=cpu),
      %transformer.layer_norm.weight : Float(20, strides=[1], requires_grad=1, device=cpu),
      %transformer.layer_norm.bias : Float(20, strides=[1], requires_grad=1, device=cpu),
      %transformer.feed_forward.intermediate_dense.bias : Float(20, strides=[1], requires_grad=1, device=cpu),
      %transformer.feed_forward.output_dense.bias 