# ReDimNetNoMel Disable bad layers

===============================================================

* build a new no-Mel model based on the baseline
* turn off the bad layers (the ones that fail to convert) by replacing them with identity
* store to ONNX and test the conversion
* no need to run voice through it, as the identity layers kill the weights and the model output is garbage
===============================================================

In [12]:
%load_ext autoreload
%autoreload 2
## our utils
from utils.common_import import *


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [13]:
%%capture --no-display
import my_utils as myUtils
from play1_setBase_line_B2 import original_model,base_line_embedding

ReDimNetWrap expects raw 16 kHz mono audio, exactly 32 000 samples

In [14]:
from torchinfo import summary
# Baseline reference: layer table of the unmodified model (MelBanks front-end
# included) for one clip of 32 000 raw samples.
summary(original_model, input_size=(1, 32000))

Layer (type:depth-idx)                                       Output Shape              Param #
ReDimNetWrap                                                 [1, 192]                  --
├─MelBanks: 1-1                                              [1, 72, 134]              --
│    └─Sequential: 2-1                                       [1, 72, 134]              --
│    │    └─Identity: 3-1                                    [1, 32000]                --
│    │    └─PreEmphasis: 3-2                                 [1, 32000]                --
│    │    └─MelSpectrogram: 3-3                              [1, 72, 134]              --
├─ReDimNet: 1-2                                              [1, 1152, 134]            --
│    └─Sequential: 2-2                                       [1, 1152, 134]            --
│    │    └─Conv2d: 3-4                                      [1, 16, 72, 134]          160
│    │    └─LayerNorm: 3-5                                   [1, 16, 72, 134]          32
│   

## find bad layers

The layer tree of the problematic block (`TimeContextBlock1d`) is:

```
TimeContextBlock1d
├── red_dim_conv (Sequential)
│   ├── Conv1d(600 → 60, kernel_size=1)
│   └── LayerNorm(C=60, data_format=channels_first)
├── tcm (Sequential)
│   ├── ConvNeXtLikeBlock (kernel=7)
│   │   ├── dwconvs: Conv1d(60 → 60, kernel_size=7, groups=60)
│   │   ├── norm: BatchNorm1d(60)
│   │   ├── act: GELU
│   │   └── pwconv1: Conv1d(60 → 60, kernel_size=1)
│   ├── ConvNeXtLikeBlock (kernel=19)
│   │   ├── dwconvs: Conv1d(60 → 60, kernel_size=19, groups=60)
│   │   ├── norm: BatchNorm1d(60)
│   │   ├── act: GELU
│   │   └── pwconv1: Conv1d(60 → 60, kernel_size=1)
│   ├── ConvNeXtLikeBlock (kernel=31)
│   │   ├── dwconvs: Conv1d(60 → 60, kernel_size=31, groups=60)
│   │   ├── norm: BatchNorm1d(60)
│   │   ├── act: GELU
│   │   └── pwconv1: Conv1d(60 → 60, kernel_size=1)
│   ├── ConvNeXtLikeBlock (kernel=59)
│   │   ├── dwconvs: Conv1d(60 → 60, kernel_size=59, groups=60)
│   │   ├── norm: BatchNorm1d(60)
│   │   ├── act: GELU
│   │   └── pwconv1: Conv1d(60 → 60, kernel_size=1)
│   └── TransformerEncoderLayer
│       ├── attention (MultiHeadAttention)
│       │   ├── k_proj: Linear(60 → 60)
│       │   ├── v_proj: Linear(60 → 60)
│       │   ├── q_proj: Linear(60 → 60)
│       │   └── out_proj: Linear(60 → 60)
│       ├── layer_norm: LayerNorm(60)
│       ├── feed_forward
│       │   ├── intermediate_dropout: Dropout(0.0)
│       │   ├── intermediate_dense: Linear(60 → 60)
│       │   ├── intermediate_act_fn: NewGELUActivation
│       │   ├── output_dense: Linear(60 → 60)
│       │   └── output_dropout: Dropout(0.0)
│       └── final_layer_norm: LayerNorm(60)
└── exp_dim_conv: Conv1d(60 → 600, kernel_size=1)
```

In [15]:
def find_modules_by_classname(model, classname: str):
    """Collect every descendant of ``model`` whose class name equals ``classname``.

    The tree is walked depth-first, in registration order, through
    ``named_children()``. ``model`` itself is never tested, only its
    descendants.

    Args:
        model: Root object exposing ``named_children()`` (e.g. a
            ``torch.nn.Module``).
        classname: Exact, case-sensitive ``__class__.__name__`` to look for.

    Returns:
        A list of ``(dotted_path, module)`` tuples in pre-order, where
        ``dotted_path`` is the chain of child names joined with ``.``.
    """
    found = []

    # Explicit stack instead of recursion. Children are pushed in reverse so
    # that they are popped (and reported) in their registration order.
    pending = list(model.named_children())[::-1]
    while pending:
        path, module = pending.pop()
        if module.__class__.__name__ == classname:
            found.append((path, module))
        nested = [
            (f'{path}.{child_name}', child)
            for child_name, child in module.named_children()
        ]
        pending.extend(reversed(nested))

    return found


# model = ReDimNet()
# Class name of the block type whose layers are treated as problematic.
classname = "TimeContextBlock1d"
# (dotted_path, module) pairs for every matching descendant of original_model.
matches = find_modules_by_classname(original_model, classname)

# One line per hit: running index, dotted path inside the model, class name.
for i, (path, layer) in enumerate(matches):
    print(f"[{i}] {path} → {layer.__class__.__name__}(")


[0] backbone.stage0.6 → TimeContextBlock1d(
[1] backbone.stage1.6 → TimeContextBlock1d(
[2] backbone.stage2.7 → TimeContextBlock1d(
[3] backbone.stage3.8 → TimeContextBlock1d(
[4] backbone.stage4.8 → TimeContextBlock1d(
[5] backbone.stage5.8 → TimeContextBlock1d(


##  new model with Identity layers

In [16]:
########################################
# 2) Define a Model Class without MelBanks
########################################
import torch
import torch.nn as nn

class ReDimNetNoMel(nn.Module):
    """ReDimNetWrap without the MelBanks front-end and with the TCM convs disabled.

    Reuses ``backbone``, ``pool``, ``bn`` and ``linear`` from ``original_wrap``
    and leaves out its 'spec' (MelBanks) module, so the input must be a
    precomputed mel spectrogram of shape ``[B, 1, n_mels, time_frames]``.

    In every TimeContextBlock1d listed in ``TCM_BLOCK_LOCATIONS``, the first
    ``NUM_CONVNEXT_BLOCKS`` entries of ``tcm`` get ``dwconvs[0]`` and
    ``pwconv1`` replaced by ``nn.Identity``.

    NOTE: the submodules are shared with ``original_wrap``, not copied, so the
    replacement also changes ``original_wrap`` in place.

    Args:
        original_wrap: Model exposing ``backbone``, ``pool``, ``bn`` and
            ``linear``; ``backbone`` must expose ``stage0`` .. ``stage5``.
        log_shapes: When True (default, the historical behaviour), ``forward``
            prints the backbone output shape on every call.
    """

    # (stage index, block index) of each TimeContextBlock1d in the backbone,
    # i.e. backbone.stage0[6], backbone.stage1[6], ... backbone.stage5[8].
    TCM_BLOCK_LOCATIONS = ((0, 6), (1, 6), (2, 7), (3, 8), (4, 8), (5, 8))

    # tcm[0] .. tcm[3] are the ConvNeXtLikeBlocks; later tcm entries are left
    # untouched.
    NUM_CONVNEXT_BLOCKS = 4

    def __init__(self, original_wrap, *, log_shapes: bool = True):
        super().__init__()
        self.log_shapes = log_shapes

        # Grab references to the submodules we want to keep (shared, not copied).
        self.backbone = original_wrap.backbone

        # Experiment notes carried over from the original code ("PASS"
        # presumably refers to the downstream conversion -- confirm):
        #   - replacing the whole tcm[i] with nn.Identity()      -> PASS
        #   - replacing only tcm[i].act with nn.SiLU()           -> NOT PASS
        #   - replacing all Conv1d layers (what is done below)   -> PASS
        for stage_idx, block_idx in self.TCM_BLOCK_LOCATIONS:
            stage = getattr(self.backbone, f'stage{stage_idx}')
            tcm = stage[block_idx].tcm
            for tcm_idx in range(self.NUM_CONVNEXT_BLOCKS):
                conv_block = tcm[tcm_idx]
                conv_block.dwconvs[0] = nn.Identity()
                conv_block.pwconv1 = nn.Identity()

        # The pooling layer is reused unchanged; no RKNN-safe replacement is
        # installed here.
        self.pool = original_wrap.pool
        self.bn = original_wrap.bn
        self.linear = original_wrap.linear

    def forward(self, x):
        """Map a mel spectrogram ``[B, 1, n_mels, time_frames]`` to an embedding.

        Returns the output of ``linear(bn(pool(backbone(x))))``.
        """
        # (1) Pass through the backbone
        x = self.backbone(x)    # shape might become [B, channels, frames] or similar
        if self.log_shapes:
            print("Backbone output shape:", x.shape)
        # (2) Pooling
        x = self.pool(x)        # ASTP => shape likely [B, embedding_dim]
        # (3) BatchNorm
        x = self.bn(x)
        # (4) Final linear => 192-dim (if that's your embedding size)
        x = self.linear(x)
        return x


# Create an instance of our new model that skips the MelBanks front-end.
# NOTE: the wrapper keeps references to original_model's submodules and swaps
# layers inside the shared backbone, so original_model is modified as well.
model_no_mel = ReDimNetNoMel(original_model)



### run to see if it works


In [17]:
# Smoke test: one forward pass on random data to confirm the patched model runs.
# The network contains BatchNorm layers, so it is switched to inference mode
# before the batch-of-one input is pushed through.
model_no_mel.eval()  # <- this line is critical!
# Random stand-in for a mel spectrogram: [B=1, 1, n_mels=72, time_frames=200].
dummy = torch.randn(1, 1, 72, 200)
model_no_mel(dummy)

Backbone output shape: torch.Size([1, 1152, 200])


tensor([[-0.0652, -0.6970,  1.1959, -2.0029,  1.7614,  0.9690, -1.7882, -1.2452,
          4.0278, -1.4662, -0.4123,  0.0115, -1.1765,  0.5948, -1.1992,  1.1045,
          1.3353, -1.2483, -2.1060,  0.8185,  0.6821,  0.6542, -0.7927, -1.0899,
          2.2258, -2.6792, -0.3295,  1.0098, -2.5207,  0.2733,  0.0883,  1.5561,
         -0.4259, -1.4513, -0.2319, -1.2865, -2.8042,  1.6346, -0.2778,  0.0300,
          1.0930, -0.4534, -0.7703, -1.4452, -0.6968, -1.3758, -0.0780, -0.3094,
         -0.8326, -0.3129,  1.8246,  1.9380,  0.6543,  0.2583, -0.5174,  1.1465,
          0.2074,  0.2285, -0.1807, -0.6855, -0.3086, -1.5716,  1.2913, -0.1225,
         -1.6333,  0.6647,  0.7337,  2.3451, -0.2833, -1.6306, -1.1414, -1.5993,
         -0.6701,  3.3115, -1.2042,  1.7720,  0.3653,  2.2990,  0.2771,  0.3368,
          2.1119, -0.5693, -1.2039, -0.1968,  0.4912, -0.2070,  0.2387,  0.3039,
         -0.8871, -0.6586,  0.6221,  1.0567, -0.2817, -1.9561,  0.2828,  1.4142,
         -1.9354,  0.6268,  

In [18]:
# Print one patched ConvNeXtLikeBlock (stage0, block 6, tcm[0]) to confirm that
# dwconvs[0] and pwconv1 were swapped for Identity.
# Alternative target kept from an earlier run (note: the printed label is hard-coded):
# print("stage0.6 =", model_no_mel.backbone.stage4[7].tcm[0]) 
print("stage0.6 =", model_no_mel.backbone.stage0[6].tcm[0])

stage0.6 = ConvNeXtLikeBlock(
  (dwconvs): ModuleList(
    (0): Identity()
  )
  (norm): BatchNorm1d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act): GELU(approximate='none')
  (pwconv1): Identity()
)


## info

In [19]:
# eval() returns the module itself, so as the last expression of the cell it
# also displays the full layer tree of the patched model.
model_no_mel.eval()


ReDimNetNoMel(
  (backbone): ReDimNet(
    (stem): Sequential(
      (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=same)
      (1): LayerNorm(C=(16,), data_format=channels_first, eps=1e-06)
      (2): to1d()
    )
    (stage0): Sequential(
      (0): weigth1d(w=(1, 1, 1, 1),sequential=False)
      (1): to2d(f=72,c=16)
      (2): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1))
      (3): ConvBlock2d(
        (conv_block): ConvNeXtLikeBlock(
          (dwconvs): ModuleList(
            (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=same, groups=4)
          )
          (norm): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act): GELU(approximate='none')
          (pwconv1): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (4): ConvBlock2d(
        (conv_block): ConvNeXtLikeBlock(
          (dwconvs): ModuleList(
            (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=

In [20]:
# Layer table for the patched model on a (1, 1, 72, 200) mel-shaped input.
# summary() runs a forward pass, which is why the "Backbone output shape"
# debug line from forward() is printed before the table.
summary(model_no_mel, (1, 1, 72, 200))


Backbone output shape: torch.Size([1, 1152, 200])


Layer (type:depth-idx)                                       Output Shape              Param #
ReDimNetNoMel                                                [1, 192]                  --
├─ReDimNet: 1-1                                              [1, 1152, 200]            --
│    └─Sequential: 2-1                                       [1, 1152, 200]            --
│    │    └─Conv2d: 3-1                                      [1, 16, 72, 200]          160
│    │    └─LayerNorm: 3-2                                   [1, 16, 72, 200]          32
│    │    └─to1d: 3-3                                        [1, 1152, 200]            --
│    └─Sequential: 2-2                                       [1, 1152, 200]            --
│    │    └─weigth1d: 3-4                                    [1, 1152, 200]            (1)
│    │    └─to2d: 3-5                                        [1, 16, 72, 200]          --
│    │    └─Conv2d: 3-6                                      [1, 16, 72, 200]          272
│ 

## store

In [21]:
# Export the patched model to ONNX, then list the written file to check it
# exists and see its size.
# NOTE(review): export_to_onnx is defined in my_utils (not shown here); the
# input shape and export options it uses should be confirmed there.
myUtils.export_to_onnx(model_no_mel,onnx_path = "ReDimNet_no_mel.onnx")
!ls -lah ReDimNet_no_mel.onnx

Backbone output shape: torch.Size([1, 1152, 134])
Backbone output shape: torch.Size([1, 1152, 134])
Exported NHWC model to ReDimNet_no_mel_nhwc.onnx
Exported to ReDimNet_no_mel.onnx
-rw-rw-r-- 1 vlad vlad 17M Jul  6 06:53 ReDimNet_no_mel.onnx


In [22]:
import onnx
# Reload the exported file and run the ONNX checker on it. check_model raises
# on an invalid model, so reaching the print means the check passed. This
# validates the file's structure only, not the numerical output of the model.
onnx_model = onnx.load("ReDimNet_no_mel.onnx")
onnx.checker.check_model(onnx_model)
print("ONNX model is valid!")


ONNX model is valid!
