In [2]:
import torch
import torch.nn as nn

In [3]:
class Puk(nn.Module):
    """Single bias-free 1-D convolution mapping 1 input channel to 128 channels.

    The 5-tap kernel uses PyTorch's Conv1d defaults (stride 1, no padding),
    so an input of shape (B, 1, L) produces an output of shape (B, 128, L - 4).
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv1d(1, 128, kernel_size=5, bias=False)

    def forward(self, x):
        # Plain pass-through to the convolution; no normalization in this baseline.
        return self.conv(x)
    

class Puk2(nn.Module):
    """Bias-free 1-D convolution followed by LayerNorm across the 128 channels.

    ``nn.LayerNorm(128)`` normalizes the trailing dimension. The (B, C, L)
    convolution output is therefore moved to channels-last, normalized per time
    step over its channels, and moved back. For an input of (B, 1, L) the
    result has shape (B, 128, L - 4) and is returned as a contiguous tensor.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv1d(1, 128, kernel_size=5, bias=False)
        self.ln = nn.LayerNorm(128)

    def forward(self, x):
        feats = self.conv(x)                                  # (B, C, L)
        channels_last = feats.transpose(1, 2).contiguous()    # (B, L, C)
        normed = self.ln(channels_last)                       # normalize over C at each position
        return normed.transpose(1, 2).contiguous()            # back to (B, C, L), contiguous
    

class Puk2_Conv2D(nn.Module):
    """Conv2d formulation of ``Puk2`` for signals laid out as (B, 1, L, 1).

    The (5, 1) kernel slides only along the height axis, so a (B, 1, L, 1)
    input yields (B, 128, L - 4, 1). The singleton width axis is dropped and
    LayerNorm is applied over the 128 channels. The result has shape
    (B, 128, L - 4) and is returned as a transposed view (not made contiguous).
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=1, out_channels=128, kernel_size=(5, 1), bias=False)
        self.ln = nn.LayerNorm(128)

    def forward(self, x):
        # (B, 128, L_out, 1) -> (B, 128, L_out)
        feats = self.conv(x).squeeze(-1)
        # LayerNorm works on the last dim, so normalize in (B, L_out, 128) layout.
        normed = self.ln(feats.transpose(1, 2))
        return normed.transpose(1, 2)  # (B, 128, L_out)
    

class Puk2_Conv2D2(nn.Module):
    """Conv2d formulation of ``Puk2`` for signals laid out as (B, 1, 1, L).

    The (1, 5) kernel slides only along the width axis (stride 1, no padding),
    so a (B, 1, 1, L) input produces (B, 128, 1, L - 4). The singleton height
    axis is squeezed away and LayerNorm is applied over the 128 channels. The
    result has shape (B, 128, L - 4).

    Note: a (B, 1, L, 1) input, the layout ``Puk2_Conv2D`` expects, does not
    work here. Its width of 1 is smaller than the kernel width of 5, so the
    convolution raises.
    """

    def __init__(self):
        super(Puk2_Conv2D2, self).__init__()
        self.conv = nn.Conv2d(in_channels=1, out_channels=128, kernel_size=(1, 5), bias=False)
        self.ln = nn.LayerNorm(128)  # normalizes the trailing (channel) dim after permuting

    def forward(self, x):
        # x: (B, 1, 1, L) -- height is the singleton axis; the signal runs along width
        y = self.conv(x)            # (B, 128, 1, L_out), L_out = L - 4
        y = y.squeeze(-2)           # drop the singleton height axis -> (B, 128, L_out)

        y = y.permute(0, 2, 1)      # (B, L_out, 128)
        y = self.ln(y)              # LayerNorm over the 128 channels
        y = y.permute(0, 2, 1)      # (B, 128, L_out)
        return y
    
    


In [6]:

def export_to_onnx(model, dummy_input, onnx_name, *, opset_version=11,
                   input_names=None, output_names=None, dynamic_axes=None):
    """Put ``model`` in eval mode and export it to an ONNX file.

    Args:
        model: The ``nn.Module`` to export. ``model.eval()`` is called on it,
            so the caller's module is left in eval mode afterwards.
        dummy_input: Example input passed to ``torch.onnx.export`` to trace
            the graph.
        onnx_name: Destination file path for the ONNX model.
        opset_version: ONNX opset to target. Defaults to 11, the value this
            helper previously hard-coded.
        input_names: Optional list of names for the graph inputs. ``None``
            keeps torch's auto-generated names.
        output_names: Optional list of names for the graph outputs. ``None``
            keeps torch's auto-generated names.
        dynamic_axes: Optional mapping of input/output name -> {axis: label}
            marking axes as variable-sized. ``None`` exports fixed shapes
            taken from ``dummy_input``.
    """
    model.eval()
    torch.onnx.export(
        model,
        dummy_input,
        onnx_name,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        opset_version=opset_version,
    )
    print("Exported to", onnx_name)

In [7]:
# Build each model variant, run a forward pass on a matching dummy input, and
# export it to ONNX. Every variant applies a 5-tap "valid" convolution to a
# 1000-sample signal, so all four must produce a (1, 128, 996) output.
torch.manual_seed(0)  # reproducible weights and dummy inputs across re-runs

SIGNAL_LEN = 1000
KERNEL_TAPS = 5
EXPECTED_SHAPE = torch.Size([1, 128, SIGNAL_LEN - KERNEL_TAPS + 1])

# (model class, dummy-input shape, output file) -- the input layout must match
# the axis each variant convolves along.
export_cases = [
    (Puk, (1, 1, SIGNAL_LEN), "puk1.onnx"),
    (Puk2, (1, 1, SIGNAL_LEN), "puk2.onnx"),
    (Puk2_Conv2D, (1, 1, SIGNAL_LEN, 1), "puk3.onnx"),
    (Puk2_Conv2D2, (1, 1, 1, SIGNAL_LEN), "puk4.onnx"),
]

for model_cls, input_shape, onnx_name in export_cases:
    model = model_cls()
    print(model)
    x = torch.randn(*input_shape)
    y = model(x)
    export_to_onnx(model, x, onnx_name)
    print(y.shape)
    assert y.shape == EXPECTED_SHAPE, (
        f"{model_cls.__name__}: unexpected output shape {tuple(y.shape)}"
    )

Puk(
  (conv): Conv1d(1, 128, kernel_size=(5,), stride=(1,), bias=False)
)
Exported to puk1.onnx
torch.Size([1, 128, 996])
Puk2(
  (conv): Conv1d(1, 128, kernel_size=(5,), stride=(1,), bias=False)
  (ln): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
)
Exported to puk2.onnx
torch.Size([1, 128, 996])
Puk2_Conv2D(
  (conv): Conv2d(1, 128, kernel_size=(5, 1), stride=(1, 1), bias=False)
  (ln): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
)
Exported to puk3.onnx
torch.Size([1, 128, 996])
Puk2_Conv2D2(
  (conv): Conv2d(1, 128, kernel_size=(1, 5), stride=(1, 1), bias=False)
  (ln): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
)
Exported to puk4.onnx
torch.Size([1, 128, 996])
