In [None]:
import torch
import torch.nn as nn
import torchaudio
from torchaudio.transforms import MFCC


In [None]:
class CausalConv1d(nn.Module):
    """1-D convolution that never looks at future time steps.

    The last axis of the input is left-padded with
    ``(kernel_size - 1) * dilation`` zeros, then an unpadded ``nn.Conv1d`` is
    applied. Each output step therefore depends only on the current and
    earlier input steps. With ``stride=1`` the time length is preserved.

    Args:
        in_channels: number of channels in the input tensor.
        out_channels: number of channels produced by the convolution.
        kernel_size: width of the convolution kernel.
        dilation: spacing between kernel taps.
        stride: step of the convolution along the time axis.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1):
        super().__init__()
        # Overhang of the dilated kernel that must be filled from the "past".
        left_context = dilation * (kernel_size - 1)
        self.pad = left_context
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=0,  # padding is done manually (left side only) in forward
            dilation=dilation,
        )

    def forward(self, x):
        """Left-pad the last axis of ``x`` with zeros and convolve it."""
        padded = nn.functional.pad(x, (self.pad, 0))
        return self.conv(padded)

In [None]:
class swishnet_module(nn.Module):
    """Causal convolution followed by a ``sigmoid * tanh`` gate.

    Args:
        filter: number of output channels of the causal convolution.
        length: kernel size of the causal convolution.
        in_channels: number of input channels.
        dilation: dilation of the causal convolution.
        stride: stride of the causal convolution.

    NOTE(review): the sigmoid gate and the tanh branch are both computed from
    the *same* convolution output. Gated units in the WaveNet style usually
    use two separate convolutions for the two branches; confirm that sharing
    one convolution here is intended.
    """

    def __init__(self, filter, length, in_channels, dilation=1, stride=1):
        super().__init__()
        self.tanh = nn.Tanh()
        self.sig = nn.Sigmoid()
        # Attribute name is kept verbatim: it determines the state_dict keys.
        self.CausalConv1d = CausalConv1d(
            in_channels=in_channels,
            out_channels=filter,
            kernel_size=length,
            dilation=dilation,
            stride=stride,
        )

    def forward(self, x):
        """Return ``sigmoid(conv(x)) * tanh(conv(x))`` for the causal conv."""
        features = self.CausalConv1d(x)
        return self.sig(features) * self.tanh(features)

In [None]:
class swishnet(nn.Module):
    """SwishNet-style classifier that takes raw waveforms.

    Pipeline:
      1. MFCC front end.
      2. A stack of causal gated convolution blocks (``swishnet_module``),
         combined by concatenation and residual addition.
      3. A 1x1 convolution producing one channel per class.
      4. Global average pooling over time.

    ``forward`` returns un-normalised class scores. ``self.soft`` is defined
    but not applied.

    Args:
        sample_rate: sample rate of the input waveform, passed to ``MFCC``.
        num_classes: number of output classes, i.e. the channels of the final
            1x1 convolution. Defaults to 4, the value previously hard-coded.
        n_mfcc: number of MFCC coefficients to extract. It is also the input
            channel count of the first two blocks. Defaults to 20 and must not
            exceed ``n_mels`` (40).
        dropout_p: dropout probability used after each stage. Defaults to 0.2.
    """

    def __init__(self, sample_rate, num_classes=4, n_mfcc=20, dropout_p=0.2):
        super(swishnet, self).__init__()

        self.dropout = nn.Dropout(p=dropout_p)

        self.mfcc_transform = MFCC(
            sample_rate=sample_rate,
            n_mfcc=n_mfcc,              # number of MFCC coefficients to extract
            melkwargs={
                "n_fft": 400,
                "hop_length": 200,
                "n_mels": 40,
                "center": False,
                "power": 2.0,
            },
        )
        # Stage 1: parallel kernel-3 / kernel-6 branches on the MFCC channels.
        self.block_1_up = swishnet_module(in_channels=n_mfcc, filter=16, length=3, dilation=1)
        self.block_1_down = swishnet_module(in_channels=n_mfcc, filter=16, length=6, dilation=1)

        # Stage 2: parallel branches on the 32 concatenated channels.
        self.block_2_up = swishnet_module(in_channels=32, filter=8, length=3, dilation=1)
        self.block_2_down = swishnet_module(in_channels=32, filter=8, length=6, dilation=1)

        # Stage 3: parallel branches whose 16-channel output is added residually.
        self.block_3_up = swishnet_module(in_channels=16, filter=8, length=3, dilation=1)
        self.block_3_down = swishnet_module(in_channels=16, filter=8, length=6, dilation=1)

        # Stages 4-8: single dilated blocks; block 8 widens to 32 channels.
        self.block_4 = swishnet_module(in_channels=16, filter=16, length=3, dilation=3)
        self.block_5 = swishnet_module(in_channels=16, filter=16, length=3, dilation=2)
        self.block_6 = swishnet_module(in_channels=16, filter=16, length=3, dilation=2)
        self.block_7 = swishnet_module(in_channels=16, filter=16, length=3, dilation=2)
        self.block_8 = swishnet_module(in_channels=16, filter=32, length=3, dilation=2)

        # 80 = z7 (16) + z8 (32) + z5 (16) + z6 (16), concatenated in forward().
        self.cnn_9 = nn.Conv1d(in_channels=80, out_channels=num_classes, kernel_size=1)
        self.AdaptiveAvgPool1d = nn.AdaptiveAvgPool1d(1)
        # Kept for backward compatibility; forward() returns scores without it.
        self.soft = nn.Softmax(dim=1)

    def forward(self, waveform):
        """Compute class scores for a batch of waveforms.

        Args:
            waveform: audio tensor. The shape comments below assume
                ``(B, 1, T)``; TODO confirm with callers.

        Returns:
            Tensor of shape ``(B, num_classes)`` holding class scores, with no
            softmax applied. ``F`` below denotes the number of MFCC frames.
        """
        mfcc = self.mfcc_transform(waveform)    # (B, 1, T) -> (B, 1, n_mfcc, F)
        mfcc = mfcc.squeeze(1)                  # -> (B, n_mfcc, F)

        # Stage 1: two branches, concatenated.
        x1 = self.block_1_up(mfcc)              # -> (B, 16, F)
        y1 = self.block_1_down(mfcc)            # -> (B, 16, F)
        z1 = self.dropout(torch.cat([x1, y1], dim=1))   # -> (B, 32, F)

        # Stage 2: two branches, concatenated.
        x2 = self.block_2_up(z1)                # -> (B, 8, F)
        y2 = self.block_2_down(z1)              # -> (B, 8, F)
        z2 = self.dropout(torch.cat([x2, y2], dim=1))   # -> (B, 16, F)

        # Stage 3: two branches, concatenated, plus a residual connection.
        x3 = self.block_3_up(z2)                # -> (B, 8, F)
        y3 = self.block_3_down(z2)              # -> (B, 8, F)
        z3 = torch.cat([x3, y3], dim=1)         # -> (B, 16, F)
        sum3 = self.dropout(z2 + z3)

        # Stages 4-6: dilated blocks with residual connections.
        z4 = self.block_4(sum3)                 # -> (B, 16, F)
        sum4 = self.dropout(sum3 + z4)

        z5 = self.block_5(sum4)                 # -> (B, 16, F)
        sum5 = self.dropout(sum4 + z5)

        z6 = self.block_6(sum5)                 # -> (B, 16, F)
        sum6 = self.dropout(sum5 + z6)

        # Stage 7: no residual connection.
        z7 = self.dropout(self.block_7(sum6))   # -> (B, 16, F)

        # Stage 8 widens to 32 channels; skip features are gathered for the head.
        z8 = self.block_8(z7)                   # -> (B, 32, F)
        cat8 = self.dropout(torch.cat([z7, z8, z5, z6], dim=1))  # -> (B, 80, F)

        scores = self.cnn_9(cat8)               # -> (B, num_classes, F)
        # NOTE(review): dropout is applied to the per-frame class scores before
        # pooling, which is unusual for an output layer; confirm it is intended.
        scores = self.dropout(scores)

        return self.AdaptiveAvgPool1d(scores).squeeze(-1)  # -> (B, num_classes)