In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:


# ----- Custom utility blocks -----
class to1d(nn.Module):
    """Collapse a 4-D tensor into 3-D by merging its second and third axes."""

    def forward(self, x):
        """Reshape ``(B, C, H, W)`` into ``(B, C * H, W)``.

        Args:
            x: 4-D tensor. Any other rank fails when the shape is unpacked
                into four names.

        Returns:
            Tensor of shape ``(B, C * H, W)``. It aliases ``x`` when the memory
            layout permits and is a copy otherwise.
        """
        B, C, H, W = x.shape
        # reshape() rather than view(): view() raises a RuntimeError when the
        # C and H axes cannot be merged without copying (for example after a
        # transpose/permute), whereas reshape() copies only in that case.
        return x.reshape(B, C * H, W)

class to2d(nn.Module):
    """Split the middle axis of a 3-D tensor into two axes of sizes ``c`` and ``f``."""

    def __init__(self, f, c):
        """Remember the target axis sizes.

        Args:
            f: Size of the third axis of the output.
            c: Size of the second axis of the output.
        """
        super().__init__()
        self.f, self.c = f, c

    def forward(self, x):
        """View ``(B, c * f, T)`` as ``(B, c, f, T)``.

        Args:
            x: 3-D tensor whose middle axis has ``c * f`` elements.

        Returns:
            A view of ``x`` with shape ``(B, c, f, T)``.
        """
        batch, _, steps = x.shape
        target_shape = (batch, self.c, self.f, steps)
        return x.view(target_shape)

class weigth1d(nn.Module):
    """Multiply the input by a learnable tensor that starts out as all ones."""

    def __init__(self, w=(1, 1, 1, 1), sequential=False):
        """Create the learnable weight.

        Args:
            w: Shape handed to ``torch.ones`` for the weight tensor.
            sequential: Accepted for signature compatibility; this class does
                not read it.
        """
        super().__init__()
        initial = torch.ones(w)
        self.w = nn.Parameter(initial)

    def forward(self, x):
        """Return ``x`` multiplied element-wise (with broadcasting) by the weight."""
        scaled = x * self.w
        return scaled

class NewGELUActivation(nn.Module):
    """Module wrapper around ``F.gelu`` called with its default arguments."""

    def forward(self, x):
        """Apply GELU element-wise and return the result."""
        activated = F.gelu(x)
        return activated

# ----- Core Components -----
class ConvNeXtLikeBlock(nn.Module):
    """Depthwise conv -> batch norm -> GELU -> pointwise conv, all on ``dim`` channels."""

    def __init__(self, dim, kernel):
        """Build the four stages.

        Args:
            dim: Channel count; kept unchanged by every stage.
            kernel: Kernel size of the depthwise convolution.
        """
        super().__init__()
        # groups=dim gives each channel its own filter (depthwise);
        # padding='same' keeps the length of the last axis unchanged.
        self.dwconv = nn.Conv1d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=kernel,
            padding='same',
            groups=dim,
        )
        self.norm = nn.BatchNorm1d(num_features=dim)
        self.act = nn.GELU()
        # Kernel size 1 mixes channels independently at each position.
        self.pwconv1 = nn.Conv1d(in_channels=dim, out_channels=dim, kernel_size=1)

    def forward(self, x):
        """Run the stages in order on a channels-first input and return the result."""
        for stage in (self.dwconv, self.norm, self.act, self.pwconv1):
            x = stage(x)
        return x

class MultiHeadAttention(nn.Module):
    """Attention placeholder: owns k/q/v/out projections but applies only v and out."""

    def __init__(self, dim):
        """Create four ``dim -> dim`` linear projections.

        Args:
            dim: Size of the last axis of the input and output.
        """
        super().__init__()
        # Registration order (k, q, v, out) determines both the state_dict key
        # order and the order in which random initial weights are drawn.
        for name in ("k_proj", "q_proj", "v_proj", "out_proj"):
            setattr(self, name, nn.Linear(dim, dim))

    def forward(self, x):
        """Apply the value projection followed by the output projection.

        No attention scores are computed; ``k_proj`` and ``q_proj`` are not used.
        """
        values = self.v_proj(x)
        return self.out_proj(values)  # stub

class FeedForward(nn.Module):
    """Linear -> GELU -> Linear, with every layer ``dim`` wide."""

    def __init__(self, dim):
        """Create the two dense layers and the activation.

        Args:
            dim: Size of the last axis of the input, hidden state and output.
        """
        super().__init__()
        self.intermediate_dense = nn.Linear(in_features=dim, out_features=dim)
        self.intermediate_act_fn = NewGELUActivation()
        self.output_dense = nn.Linear(in_features=dim, out_features=dim)

    def forward(self, x):
        """Project, activate, then project again along the last axis."""
        hidden = self.intermediate_act_fn(self.intermediate_dense(x))
        return self.output_dense(hidden)

class TransformerEncoderLayer(nn.Module):
    """Attention -> LayerNorm -> feed-forward -> LayerNorm, with no residual additions."""

    def __init__(self, dim):
        """Create the sub-layers.

        Args:
            dim: Size of the last axis, which every sub-layer operates on.
        """
        super().__init__()
        self.attention = MultiHeadAttention(dim)
        self.layer_norm = nn.LayerNorm(normalized_shape=dim)
        self.feed_forward = FeedForward(dim)
        self.final_layer_norm = nn.LayerNorm(normalized_shape=dim)

    def forward(self, x):
        """Chain the four sub-layers; each one acts on the last axis of ``x``."""
        attended = self.layer_norm(self.attention(x))
        return self.final_layer_norm(self.feed_forward(attended))

class TimeContextBlock1d(nn.Module):
    """Bottleneck block: 1x1 conv down, context modules, 1x1 conv up.

    Operates on channels-first input: ``(B, in_dim, T)`` -> ``(B, exp_dim, T)``.
    """

    def __init__(self, in_dim, reduced_dim, exp_dim):
        """Build the block.

        Args:
            in_dim: Number of input channels.
            reduced_dim: Channel width used inside the block.
            exp_dim: Number of output channels.
        """
        super().__init__()
        self.red_dim_conv = nn.Sequential(
            nn.Conv1d(in_dim, reduced_dim, kernel_size=1),
            nn.LayerNorm([reduced_dim])
        )
        self.tcm = nn.Sequential(
            ConvNeXtLikeBlock(reduced_dim, 7),
            ConvNeXtLikeBlock(reduced_dim, 19),
            ConvNeXtLikeBlock(reduced_dim, 31),
            ConvNeXtLikeBlock(reduced_dim, 59),
            TransformerEncoderLayer(reduced_dim)
        )
        self.exp_dim_conv = nn.Conv1d(reduced_dim, exp_dim, kernel_size=1)

    def forward(self, x):
        """Run the block on a ``(B, in_dim, T)`` tensor.

        The containers are stepped through by hand instead of being called as a
        whole because they mix channels-first modules (Conv1d, BatchNorm1d)
        with modules that act on the last axis (LayerNorm, Linear). Calling
        them directly on ``(B, reduced_dim, T)`` data only runs when
        ``T == reduced_dim``, and then normalises over time, not channels.

        Returns:
            Tensor of shape ``(B, exp_dim, T)``.
        """
        reduce_conv, reduce_norm = self.red_dim_conv[0], self.red_dim_conv[1]
        x = reduce_conv(x)
        # LayerNorm([reduced_dim]) normalises the last axis, so move channels
        # there for the call and restore the channels-first layout afterwards.
        x = reduce_norm(x.transpose(1, 2)).transpose(1, 2)

        # The last entry of self.tcm is the encoder layer built in __init__;
        # everything before it is a channels-first convolutional block.
        *conv_stages, encoder = self.tcm
        for stage in conv_stages:
            x = stage(x)
        # The encoder's Linear/LayerNorm layers are reduced_dim wide on the
        # last axis, so it gets the same channels-last treatment.
        x = encoder(x.transpose(1, 2)).transpose(1, 2)

        x = self.exp_dim_conv(x)
        return x

# ----- Output Pooling and Final Layers -----
class ASTP(nn.Module):
    """Two stacked 1x1 ``Conv1d`` layers with no non-linearity between them.

    NOTE(review): this is a placeholder. No attention weights or statistics
    are computed here, only the two convolutions.
    """

    def __init__(self, in_dim=1800, bottleneck_dim=128, out_dim=600):
        """Create the two convolutions.

        Args:
            in_dim: Input channels of ``linear1``.
            bottleneck_dim: Output channels of ``linear1`` and input channels
                of ``linear2``.
            out_dim: Output channels of ``linear2``.

        The defaults reproduce the previously hard-coded 1800 -> 128 -> 600.
        """
        super().__init__()
        self.linear1 = nn.Conv1d(in_dim, bottleneck_dim, kernel_size=1)
        self.linear2 = nn.Conv1d(bottleneck_dim, out_dim, kernel_size=1)

    def forward(self, x):
        """Map ``(B, in_dim, T)`` to ``(B, out_dim, T)``."""
        return self.linear2(self.linear1(x))

# ----- ReDimNetNoMel stub main -----
class testLayer06(nn.Module):
    """Stub head: pool -> duplicate channels -> batch norm -> time average -> linear."""

    def __init__(self):
        """Create the pooling stub, a 1200-channel batch norm and a 1200 -> 192 linear layer."""
        super().__init__()
        self.pool = ASTP()
        self.bn = nn.BatchNorm1d(num_features=1200)
        self.linear = nn.Linear(in_features=1200, out_features=192)

    def forward(self, x):
        """Map a ``(B, 1800, T)`` input to a ``(B, 192)`` output."""
        pooled = self.pool(x)                           # (B, 600, T)
        # The pooled tensor is concatenated with itself to reach the
        # 1200 channels that the batch norm and linear layer are sized for.
        doubled = torch.cat([pooled, pooled], dim=1)    # (B, 1200, T)
        normed = self.bn(doubled)                       # per-channel batch norm
        averaged = normed.mean(dim=2)                   # mean over T -> (B, 1200)
        return self.linear(averaged)                    # (B, 192)




In [3]:
# Build the stub and switch it to inference mode (eval() returns the module).
model = testLayer06().eval()

# Example input traced by the exporter: batch 1, 1800 channels, one time step.
dummy_input = torch.randn((1, 1800, 1))

# NOTE(review): no dynamic_axes are passed, so the exported graph is presumably
# fixed to the dummy input's shape; confirm that is what the consumer expects.
torch.onnx.export(
    model,
    dummy_input,
    "testLayer06.onnx",
    opset_version=13,
    input_names=["log_mel"],
    output_names=["embedding"],
)