# ReDimNetNoMel Disable bad layers

===============================================================

* build new noMel model based on base line
* turn off bad layers that not converting to identity
* store to onnx and test to convert
* no need to run voice through as the identity layers kill the wights and the model output is garbage
===============================================================

In [1]:
%load_ext autoreload
%autoreload 2
## our utils
from utils.common_import import *


2.6.0+cu124


In [2]:
%%capture --no-display
import my_utils as myUtils
from play1_setBase_line_B0 import original_model,base_line_embedding

ReDimNetWrap expects raw 16 kHz mono audio, exactly 32 000 samples

In [3]:
from torchinfo import summary
summary(original_model, input_size=(1, 32000))

Layer (type:depth-idx)                                       Output Shape              Param #
ReDimNetWrap                                                 [1, 192]                  --
├─MelBanks: 1-1                                              [1, 60, 134]              --
│    └─Sequential: 2-1                                       [1, 60, 134]              --
│    │    └─Identity: 3-1                                    [1, 32000]                --
│    │    └─PreEmphasis: 3-2                                 [1, 32000]                --
│    │    └─MelSpectrogram: 3-3                              [1, 60, 134]              --
├─ReDimNet: 1-2                                              [1, 600, 134]             --
│    └─Sequential: 2-2                                       [1, 600, 134]             --
│    │    └─Conv2d: 3-4                                      [1, 10, 60, 134]          100
│    │    └─LayerNorm: 3-5                                   [1, 10, 60, 134]          20
│   

##  new model with Identity layers

In [4]:
########################################
# 2) Define a Model Class without MelBanks
########################################
import torch
import torch.nn as nn

class ReDimNetNoMel(nn.Module):
    """
    A wrapper around the original ReDimNetWrap that:
      - Excludes the 'spec' (MelBanks) module
      - Uses 'backbone', 'pool', 'bn', and 'linear'
    We expect a precomputed mel spectrogram as input with shape [B, 1, n_mels, time_frames].
    """
    def __init__(self, original_wrap):
        super().__init__()
        
        # Grab references to the submodules we want to keep
        self.backbone = original_wrap.backbone
        
        
        ## DESERT SEARCH WHAT HANGS ###
        # self.backbone.stage0[6] = nn.Identity()
        # self.backbone.stage1[8] = nn.Identity()
        # self.backbone.stage2[8] = nn.Identity()
        # self.backbone.stage3[9] = nn.Identity()
        # self.backbone.stage4[7] = nn.Identity()
        ### >>>> PASS
        
        # try: only close TransformerEncoderLayer
        # self.backbone.stage0[6].tcm[4] = nn.Identity()
        # self.backbone.stage1[8].tcm[4] = nn.Identity()
        # self.backbone.stage2[8].tcm[4] = nn.Identity()
        # self.backbone.stage3[9].tcm[4] = nn.Identity()
        # self.backbone.stage4[7].tcm[4] = nn.Identity()
        # >>>> NOT PASS
        
        # try: close only 1 ConvNeXtLikeBlock
        # self.backbone.stage0[6].tcm[0] = nn.Identity()
        # self.backbone.stage1[8].tcm[0] = nn.Identity()
        # self.backbone.stage2[8].tcm[0] = nn.Identity()
        # self.backbone.stage3[9].tcm[0] = nn.Identity()
        # self.backbone.stage4[7].tcm[0] = nn.Identity()
        # >>> NOT PASS
        
        ## try: all layers of ConvNeXtLikeBlock
        for stage_idx, block_idx in [(0, 6), (1, 8), (2, 8), (3, 9), (4, 7)]:
            for tcm_idx in range(4):  # tcm[0] to tcm[3] 
                # self.backbone.__getattr__(f'stage{stage_idx}')[block_idx].tcm[tcm_idx] = nn.Identity()            ## >>>> PASS
                # self.backbone.__getattr__(f'stage{stage_idx}')[block_idx].tcm[tcm_idx].act = nn.SiLU()            ## >>>> NOT PASS
                ## !!! all Conv1d
                self.backbone.__getattr__(f'stage{stage_idx}')[block_idx].tcm[tcm_idx].dwconvs[0] = nn.Identity()
                self.backbone.__getattr__(f'stage{stage_idx}')[block_idx].tcm[tcm_idx].pwconv1 = nn.Identity()
                ## >>>> PASS

        
        
        # Replace ASTP with RKNN-safe version:
        self.pool = original_wrap.pool
        self.bn = original_wrap.bn
        self.linear = original_wrap.linear

    def forward(self, x):
        # x: shape [B, 1, n_mels, time_frames]
        # (1) Pass through the backbone
        x = self.backbone(x)    # shape might become [B, channels, frames] or similar
        print("Backbone output shape:", x.shape)  # ADD THIS LINE
        # (2) Pooling
        x = self.pool(x)        # ASTP => shape likely [B, embedding_dim]
        # (3) BatchNorm
        x = self.bn(x)
        # (4) Final linear => 192-dim (if that's your embedding size)
        x = self.linear(x)
        return x


# Create an instance of our new model that skips the MelBanks front-end
model_no_mel = ReDimNetNoMel(original_model)



### run to see if it works


In [5]:
model_no_mel.eval()  # <- this line is critical!
dummy = torch.randn(1, 1, 60, 200)
model_no_mel(dummy)

Backbone output shape: torch.Size([1, 600, 200])


tensor([[  3.9911,  -1.0193,   1.1549,   0.8218,   1.5639,   2.2639,  -0.2115,
           1.1388,   1.6812,   6.1542,   4.8047,   0.9625,  -2.5269,  -5.0245,
          -3.7159,   3.9225,  -0.2732,   1.7035,   0.1402,  -5.7757,   3.4241,
           1.3984,  -0.4500,  -1.9297,   0.9378,  -1.4857,   4.4311,   0.6974,
          -4.4153,  -7.8101,  -2.9793,  -0.8681,  -4.3332,   1.6840,   3.7524,
           4.3297,  -0.7436,   0.3217,  -0.9835,   0.7124,   0.8157,   3.0291,
          -2.3501,  -1.4309,   4.2975,  -0.3592,  -0.4157,  -2.6449,   4.6884,
           3.3947,  -2.5064,  -3.7524,  -3.2379,   2.8387,  -4.2970,  -2.3072,
           0.9681,  -0.3885,  -1.5872,  -0.8838,  -4.6249,  -1.2078,   2.1128,
           3.8171,   0.4526,   2.6848,   0.2271,   3.8320,  -0.6987,   0.2206,
           1.6393,  -5.3172,  -6.0950,  -1.6590,   2.6062,  -1.7630,  -3.3992,
           3.4335,   1.3541,  -4.4168,  -5.2642,   3.2618,   1.3587,   7.2146,
           0.6793,  -0.4032,  -0.1059,  -1.5760,   0

## layres debug

problematic layer tree is:

```
TimeContextBlock1d
├── red_dim_conv (Sequential)
│   ├── Conv1d(600 → 60, kernel_size=1)
│   └── LayerNorm(C=60, data_format=channels_first)
├── tcm (Sequential)
│   ├── ConvNeXtLikeBlock (kernel=7)
│   │   ├── dwconvs: Conv1d(60 → 60, kernel_size=7, groups=60)
│   │   ├── norm: BatchNorm1d(60)
│   │   ├── act: GELU
│   │   └── pwconv1: Conv1d(60 → 60, kernel_size=1)
│   ├── ConvNeXtLikeBlock (kernel=19)
│   │   ├── dwconvs: Conv1d(60 → 60, kernel_size=19, groups=60)
│   │   ├── norm: BatchNorm1d(60)
│   │   ├── act: GELU
│   │   └── pwconv1: Conv1d(60 → 60, kernel_size=1)
│   ├── ConvNeXtLikeBlock (kernel=31)
│   │   ├── dwconvs: Conv1d(60 → 60, kernel_size=31, groups=60)
│   │   ├── norm: BatchNorm1d(60)
│   │   ├── act: GELU
│   │   └── pwconv1: Conv1d(60 → 60, kernel_size=1)
│   ├── ConvNeXtLikeBlock (kernel=59)
│   │   ├── dwconvs: Conv1d(60 → 60, kernel_size=59, groups=60)
│   │   ├── norm: BatchNorm1d(60)
│   │   ├── act: GELU
│   │   └── pwconv1: Conv1d(60 → 60, kernel_size=1)
│   └── TransformerEncoderLayer
│       ├── attention (MultiHeadAttention)
│       │   ├── k_proj: Linear(60 → 60)
│       │   ├── v_proj: Linear(60 → 60)
│       │   ├── q_proj: Linear(60 → 60)
│       │   └── out_proj: Linear(60 → 60)
│       ├── layer_norm: LayerNorm(60)
│       ├── feed_forward
│       │   ├── intermediate_dropout: Dropout(0.0)
│       │   ├── intermediate_dense: Linear(60 → 60)
│       │   ├── intermediate_act_fn: NewGELUActivation
│       │   ├── output_dense: Linear(60 → 60)
│       │   └── output_dropout: Dropout(0.0)
│       └── final_layer_norm: LayerNorm(60)
└── exp_dim_conv: Conv1d(60 → 600, kernel_size=1)
```

In [6]:
for name, module in model_no_mel.named_modules():
    if isinstance(module, nn.LayerNorm):
        print("❌ Still has LayerNorm at:", name)

❌ Still has LayerNorm at: backbone.stage0.6.tcm.4.layer_norm
❌ Still has LayerNorm at: backbone.stage0.6.tcm.4.final_layer_norm
❌ Still has LayerNorm at: backbone.stage1.8.tcm.4.layer_norm
❌ Still has LayerNorm at: backbone.stage1.8.tcm.4.final_layer_norm
❌ Still has LayerNorm at: backbone.stage2.8.tcm.4.layer_norm
❌ Still has LayerNorm at: backbone.stage2.8.tcm.4.final_layer_norm
❌ Still has LayerNorm at: backbone.stage3.9.tcm.4.layer_norm
❌ Still has LayerNorm at: backbone.stage3.9.tcm.4.final_layer_norm
❌ Still has LayerNorm at: backbone.stage4.7.tcm.4.layer_norm
❌ Still has LayerNorm at: backbone.stage4.7.tcm.4.final_layer_norm


In [7]:
print("stage0.6 =", model_no_mel.backbone.stage4[7].tcm[0]) 

stage0.6 = ConvNeXtLikeBlock(
  (dwconvs): ModuleList(
    (0): Identity()
  )
  (norm): BatchNorm1d(60, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act): GELU(approximate='none')
  (pwconv1): Identity()
)


## info

In [8]:
model_no_mel.eval()


ReDimNetNoMel(
  (backbone): ReDimNet(
    (stem): Sequential(
      (0): Conv2d(1, 10, kernel_size=(3, 3), stride=(1, 1), padding=same)
      (1): LayerNorm(C=(10,), data_format=channels_first, eps=1e-06)
      (2): to1d()
    )
    (stage0): Sequential(
      (0): weigth1d(w=(1, 1, 1, 1),sequential=False)
      (1): to2d(f=60,c=10)
      (2): Conv2d(10, 10, kernel_size=(1, 1), stride=(1, 1))
      (3): ConvBlock2d(
        (conv_block): ResBasicBlock(
          (conv1): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=10, bias=False)
          (conv1pw): Conv2d(10, 10, kernel_size=(1, 1), stride=(1, 1))
          (bn1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=10, bias=False)
          (conv2pw): Conv2d(10, 10, kernel_size=(1, 1), stride=(1, 1))
          (bn2): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_st

In [9]:
summary(model_no_mel, (1, 1, 60, 200))


Backbone output shape: torch.Size([1, 600, 200])


Layer (type:depth-idx)                                       Output Shape              Param #
ReDimNetNoMel                                                [1, 192]                  --
├─ReDimNet: 1-1                                              [1, 600, 200]             --
│    └─Sequential: 2-1                                       [1, 600, 200]             --
│    │    └─Conv2d: 3-1                                      [1, 10, 60, 200]          100
│    │    └─LayerNorm: 3-2                                   [1, 10, 60, 200]          20
│    │    └─to1d: 3-3                                        [1, 600, 200]             --
│    └─Sequential: 2-2                                       [1, 600, 200]             --
│    │    └─weigth1d: 3-4                                    [1, 600, 200]             (1)
│    │    └─to2d: 3-5                                        [1, 10, 60, 200]          --
│    │    └─Conv2d: 3-6                                      [1, 10, 60, 200]          110
│ 

## store

In [10]:
myUtils.export_to_onnx(model_no_mel,onnx_path = "ReDimNet_no_mel.onnx")
!ls -lah ReDimNet_no_mel.onnx

Backbone output shape: torch.Size([1, 600, 134])
Exported to ReDimNet_no_mel.onnx
-rw-rw-r-- 1 vlad vlad 3.9M Jun 18 18:33 ReDimNet_no_mel.onnx


In [11]:
import onnx
onnx_model = onnx.load("ReDimNet_no_mel.onnx")
onnx.checker.check_model(onnx_model)
print("ONNX model is valid!")


ONNX model is valid!
