Reference implementation (ECAPA-TDNN model code): https://github.com/TaoRuijie/ECAPA-TDNN/blob/main/model.py#L132

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import math

In [None]:
# AdaptiveAvgPool1d pools the last axis down to the requested length,
# independent of the input length: (1, 64, 8) -> (1, 64, 5).
pool = nn.AdaptiveAvgPool1d(5)
x = torch.randn(1, 64, 8)  # named `x` rather than `input` so the builtin is not shadowed
pooled = pool(x)
x.shape, pooled.shape

(torch.Size([1, 64, 8]), torch.Size([1, 64, 5]))

In [None]:
# Conv1d output length:
#   floor((L + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1
# Here: floor((50 + 0 - 2 - 1) / 2) + 1 = 24, and channels go 16 -> 33.
conv = nn.Conv1d(16, 33, 3, stride=2, padding=0)
x = torch.randn(20, 16, 50)  # named `x` rather than `input` so the builtin is not shadowed
y = conv(x)
x.shape, y.shape

(torch.Size([20, 16, 50]), torch.Size([20, 33, 24]))

In [None]:
class SEModule(nn.Module):
    """Squeeze-and-excitation gate for Conv1d-style (batch, channels, length) inputs.

    The input is average-pooled over its last axis. The pooled result goes through a
    1x1-conv bottleneck (channels -> bottleneck -> channels) with a ReLU in
    between, then through a sigmoid. The resulting per-channel weights, all in
    (0, 1), rescale the original input.

    Args:
        channels: number of input (and output) channels.
        bottleneck: number of hidden channels in the 1x1-conv bottleneck.
    """

    def __init__(self, channels, bottleneck=128):
        super().__init__()
        squeeze = nn.AdaptiveAvgPool1d(1)
        reduce = nn.Conv1d(channels, bottleneck, kernel_size=1, padding=0)
        expand = nn.Conv1d(bottleneck, channels, kernel_size=1, padding=0)
        # Attribute name and layer order are kept so that state_dict keys
        # (se.1.*, se.3.*) stay the same.
        self.se = nn.Sequential(squeeze, reduce, nn.ReLU(), expand, nn.Sigmoid())

    def forward(self, input):
        """Return ``input`` scaled channel-wise by the learned gate."""
        # The gate has length 1 on the last axis and broadcasts over it.
        return input * self.se(input)

In [None]:
# SEModule only rescales channels, so the output keeps the input's shape.
se = SEModule(16)
x = torch.randn(20, 16, 50)  # named `x` rather than `input` so the builtin is not shadowed
y = se(x)
x.shape, y.shape

(torch.Size([20, 16, 50]), torch.Size([20, 16, 50]))

In [None]:
class Bottle2neck(nn.Module):
    """Res2Net-style bottleneck block with a squeeze-and-excitation gate.

    The block runs the following pipeline:

    1. 1x1 conv -> ReLU -> BN.
    2. Split the result into ``scale`` groups of ``width = planes // scale``
       channels each.
    3. Pass the first ``scale - 1`` groups through a hierarchy of dilated convs.
       Each group is added to the previous branch's output before its conv.
       The last group bypasses the convs.
    4. Concatenate the branches, then 1x1 conv -> ReLU -> BN -> SE.
    5. Add the block input back as a residual.

    With ``scale=1`` there are no conv branches, and the single group goes
    straight to the final 1x1 conv.

    Args:
        inplanes: channels of the input. Must equal ``planes``, because the
            input is added to the output as a residual.
        planes: channels of the output. If it is not divisible by ``scale``,
            the inner layers use ``(planes // scale) * scale`` channels.
        kernel_size: kernel size of the dilated group convs. It must be given.
            Use an odd value: the padding keeps the length unchanged only for
            odd kernels, and a length change breaks the hierarchical additions.
        dilation: dilation of the group convs. It must be given.
        scale: number of channel groups.

    Raises:
        TypeError: if ``kernel_size`` or ``dilation`` is left as ``None``.
    """

    def __init__(self, inplanes, planes, kernel_size=None, dilation=None, scale=8):
        super().__init__()
        if kernel_size is None or dilation is None:
            # The None defaults only keep the signature stable. The padding
            # computation below needs real integers.
            raise TypeError("Bottle2neck requires explicit kernel_size and dilation")
        width = planes // scale
        self.conv1 = nn.Conv1d(inplanes, width * scale, kernel_size=1)
        self.bn1 = nn.BatchNorm1d(width * scale)
        self.nums = scale - 1
        convs = []
        bns = []
        # For odd kernel sizes this padding keeps the sequence length unchanged.
        num_pad = (kernel_size // 2) * dilation

        for _ in range(self.nums):
            convs.append(nn.Conv1d(width, width, kernel_size=kernel_size, dilation=dilation, padding=num_pad))
            bns.append(nn.BatchNorm1d(width))

        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.conv3 = nn.Conv1d(width * scale, planes, kernel_size=1)
        self.bn3 = nn.BatchNorm1d(planes)

        self.relu = nn.ReLU()
        self.width = width
        self.se = SEModule(planes)

    def forward(self, x):
        residual = x
        out = self.bn1(self.relu(self.conv1(x)))

        groups = torch.split(out, self.width, 1)
        branches = []
        sp = None
        for i in range(self.nums):
            # Hierarchical connection: every group after the first also sees
            # the previous branch's output.
            sp = groups[i] if sp is None else sp + groups[i]
            sp = self.bns[i](self.relu(self.convs[i](sp)))
            branches.append(sp)
        # The last group bypasses the convs. Concatenating once avoids re-copying
        # the growing tensor on every iteration.
        branches.append(groups[self.nums])
        out = torch.cat(branches, 1)

        out = self.bn3(self.relu(self.conv3(out)))

        out = self.se(out)
        out += residual

        return out

In [None]:
class PreEmphasis(torch.nn.Module):
    """Pre-emphasis filter ``y[t] = x[t] - coef * x[t-1]`` along the last axis.

    The first sample is handled with reflect padding, so ``x[-1]`` is taken to
    be ``x[1]``. The output has the same length as the input.

    Args:
        coef: pre-emphasis coefficient.
    """

    def __init__(self, coef: float = 0.97):
        super().__init__()
        self.coef = coef
        # conv1d computes a cross-correlation, so the taps are stored in
        # "flipped" order: [-coef, 1] applied to (x[t-1], x[t]).
        # Shape: (out_channels=1, in_channels=1, kernel=2).
        self.register_buffer(
            'flipped_filter',
            torch.tensor([-self.coef, 1.0], dtype=torch.float32).unsqueeze(0).unsqueeze(0),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Filter a ``(batch, samples)`` or ``(samples,)`` signal.

        Returns a tensor with the same shape and dtype as ``input``.

        NOTE(review): reflect padding appears to require at least two samples.
        Shorter signals presumably raise from ``F.pad``.
        """
        unbatched = input.dim() == 1
        signal = input.unsqueeze(0) if unbatched else input
        padded = F.pad(signal.unsqueeze(1), (1, 0), 'reflect')
        # Match the input dtype (e.g. float64) instead of failing on a mismatch
        # with the float32 buffer. This is a no-op for float32 input.
        weight = self.flipped_filter.to(dtype=padded.dtype)
        out = F.conv1d(padded, weight).squeeze(1)
        return out.squeeze(0) if unbatched else out