# In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F


# In [2]:
class CasualMHSA(nn.Module):
    """
    Multi-head self-attention with a causal mask.

    Each time step may attend only to itself and to earlier steps, so the
    output at step t never depends on inputs after t. (The class name keeps
    the original "Casual" spelling so existing references and checkpoints
    still resolve; it means "causal".)

    Args:
        d_model: feature dimension of the input and output.
        n_head: number of attention heads; must divide ``d_model``
            (required by ``nn.MultiheadAttention``).
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, n_head, dropout=0.1):
        super().__init__()
        self.mha = nn.MultiheadAttention(d_model, n_head, dropout=dropout, batch_first=True)
        self.d_model = d_model

    def forward(self, x):
        """
        Args:
            x: tensor of shape [Batch, Time, d_model].

        Returns:
            Tensor of shape [Batch, Time, d_model].
        """
        # x.shape is a torch.Size attribute (not a method); only Time is needed.
        T = x.shape[1]
        # Boolean causal mask: True marks a (query, key) pair that is NOT allowed
        # to attend, i.e. every key strictly after the query. A bool mask works
        # for any dtype of x, unlike a float -inf mask that must match it.
        attn_mask = torch.ones(T, T, dtype=torch.bool, device=x.device).triu(diagonal=1)
        # The attention weights are discarded, so don't ask MHA to compute them.
        x_out, _ = self.mha(x, x, x, attn_mask=attn_mask, need_weights=False)
        return x_out

class ConformerConvModule(nn.Module):
    """
    Conformer convolution module.

    Pointwise conv (D -> 2D) -> GLU (back to D) -> depthwise conv over time
    -> GroupNorm -> SiLU -> pointwise conv -> dropout.

    Input and output are both [Batch, Time, Dim], and the Time length is
    preserved for odd and even kernel sizes alike.

    NOTE(review): this module is not causal. The centered padding lets each
    output frame see up to kernel_size // 2 future frames. GroupNorm's
    statistics are also computed over the whole time axis.

    Args:
        d_model: number of channels (Dim).
        kernel_size: temporal kernel size of the depthwise convolution.
        dropout: dropout probability applied to the module output.
    """

    def __init__(self, d_model, kernel_size=15, dropout=0.1):
        super().__init__()
        self.pointwise_conv1 = nn.Conv1d(d_model, d_model * 2, kernel_size=1)
        # GLU over channels halves 2*d_model back to d_model.
        self.glu = nn.GLU(dim=1)
        # Padding is applied explicitly in forward(): a symmetric
        # padding=(k-1)//2 would drop one frame whenever kernel_size is even.
        self.depthwise_conv = nn.Conv1d(d_model, d_model, kernel_size, groups=d_model)
        # Attribute keeps the name "batch_norm" for checkpoint compatibility,
        # but it is a single-group GroupNorm (batch-size independent).
        self.batch_norm = nn.GroupNorm(num_groups=1, num_channels=d_model)
        self.activation = nn.SiLU()
        self.pointwise_conv2 = nn.Conv1d(d_model, d_model, kernel_size=1)
        self.dropout = nn.Dropout(dropout)

        # (left, right) zero padding on the time axis. For odd kernels this is
        # exactly the original symmetric (k-1)//2 on each side; for even kernels
        # the extra frame goes on the right so the length is preserved.
        total_pad = kernel_size - 1
        self._time_padding = (total_pad // 2, total_pad - total_pad // 2)

    def forward(self, x):
        """
        Args:
            x: tensor of shape [Batch, Time, Dim].

        Returns:
            Tensor of shape [Batch, Time, Dim].
        """
        # Conv1d works on [Batch, Dim, Time].
        x = x.transpose(1, 2)
        x = self.pointwise_conv1(x)
        x = self.glu(x)
        x = F.pad(x, self._time_padding)
        x = self.depthwise_conv(x)
        x = self.batch_norm(x)
        x = self.activation(x)
        x = self.pointwise_conv2(x)
        x = self.dropout(x)
        # Back to [Batch, Time, Dim].
        return x.transpose(1, 2)


class FeedForwardModule(nn.Module):
    """
    Position-wise feed-forward network acting on the last dimension.

    Pipeline: Linear(d_model -> d_model * expansion_factor) -> SiLU -> Dropout
    -> Linear(back to d_model) -> Dropout.

    Args:
        d_model: input/output feature size.
        expansion_factor: hidden-layer width multiplier.
        dropout: probability used by both dropout layers.
    """

    def __init__(self, d_model, expansion_factor=4, dropout=0.1):
        super().__init__()
        hidden_dim = d_model * expansion_factor
        # Sub-modules are registered in this exact order so that parameter
        # initialisation and state_dict layout stay the same.
        self.layer1 = nn.Linear(d_model, hidden_dim)
        self.activation = nn.SiLU()
        self.dropout1 = nn.Dropout(dropout)
        self.layer2 = nn.Linear(hidden_dim, d_model)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        pipeline = (
            self.layer1,
            self.activation,
            self.dropout1,
            self.layer2,
            self.dropout2,
        )
        out = x
        for stage in pipeline:
            out = stage(out)
        return out

# In [3]:
class ConformerBlock(nn.Module):
    """
    Macaron-style Conformer block with pre-norm residual connections.

    Each sub-layer is applied as ``x + scale * f(LayerNorm(x))``. In order:

    1. feed-forward module (scale 0.5)
    2. convolution module
    3. causal multi-head self-attention
    4. feed-forward module (scale 0.5)

    The block ends with a final LayerNorm. Input and output are both
    [Batch, Time, d_model].

    NOTE: the Conformer paper (Gulati et al., 2020) places self-attention
    before the convolution module; this block runs the convolution first.

    Args:
        d_model: feature dimension.
        n_head: number of attention heads (must divide d_model).
        kernel_size: depthwise-convolution kernel size.
        dropout: dropout probability for every sub-module.
        expansion_factor: hidden-width multiplier of both feed-forward modules.
    """

    def __init__(self, d_model, n_head, kernel_size=15, dropout=0.1, expansion_factor=4):
        super().__init__()
        self.ffn1 = FeedForwardModule(d_model, expansion_factor=expansion_factor, dropout=dropout)
        self.conv_module = ConformerConvModule(d_model, kernel_size=kernel_size, dropout=dropout)
        # Attribute name kept as "MHSA" for state_dict compatibility.
        self.MHSA = CasualMHSA(d_model, n_head, dropout=dropout)
        self.ffn2 = FeedForwardModule(d_model, expansion_factor=expansion_factor, dropout=dropout)

        # One pre-norm per sub-layer, plus a norm on the block output.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.norm4 = nn.LayerNorm(d_model)
        self.final_norm = nn.LayerNorm(d_model)

    def forward(self, x):
        """
        Args:
            x: tensor of shape [Batch, Time, d_model].

        Returns:
            Tensor of shape [Batch, Time, d_model].
        """
        # Half-step feed-forward residuals (macaron style).
        x = x + 0.5 * self.ffn1(self.norm1(x))
        x = x + self.conv_module(self.norm2(x))
        # The attention sub-module is registered as self.MHSA in __init__.
        x = x + self.MHSA(self.norm3(x))
        x = x + 0.5 * self.ffn2(self.norm4(x))
        return self.final_norm(x)

# In [4]:
class WaveformAEC(nn.Module):
    """
    Time-domain acoustic echo canceller built on Conformer blocks.

    How the forward pass works:

    1. The microphone and reference waveforms are cut into overlapping
       frames (``win_len`` samples, hop ``stride``).
    2. Each frame is projected to ``d_model`` features by a per-signal
       linear encoder.
    3. A stack of ConformerBlocks estimates a multiplicative mask from the
       concatenated mic/ref features.
    4. The masked mic features are decoded back to frames, bounded by tanh.
    5. The frames are overlap-added into a waveform.

    Args:
        d_model: feature dimension inside the network.
        n_layers: number of ConformerBlocks.
        n_head: attention heads per block (must divide d_model).
        kernel_size: depthwise-convolution kernel size in each block.
        win_len: frame length in samples.
        stride: hop between consecutive frames in samples.
        dropout: dropout probability passed to every ConformerBlock.
    """

    def __init__(self,
                 d_model=128,
                 n_layers=4,
                 n_head=8,
                 kernel_size=15,
                 win_len=80,
                 stride=40,
                 dropout=0.1):
        super().__init__()
        self.win_len = win_len
        self.stride = stride
        self.d_model = d_model

        # Encoders: one frame of win_len samples -> d_model features.
        self.mix_encoder = nn.Linear(win_len, d_model)
        self.ref_encoder = nn.Linear(win_len, d_model)
        # Fuse the concatenated mic/ref features (2 * d_model) back to d_model.
        self.input = nn.Linear(d_model * 2, d_model)

        self.layers = nn.ModuleList([
            ConformerBlock(d_model, n_head, kernel_size, dropout) for _ in range(n_layers)
        ])

        # Decoder: d_model features -> one frame of win_len samples, in (-1, 1).
        self.decoder = nn.Linear(d_model, win_len)
        self.output = nn.Tanh()

    def forward(self, mic_wav, ref_wav):
        """
        Args:
            mic_wav: microphone waveform, shape [Batch, samples].
            ref_wav: reference waveform, same shape as mic_wav.

        Returns:
            Predicted waveform of shape [Batch, samples] (same length as the
            inputs).

        Raises:
            ValueError: if mic_wav and ref_wav do not have the same shape.
        """
        if mic_wav.shape != ref_wav.shape:
            raise ValueError(
                "mic_wav and ref_wav must have the same shape, got "
                f"{tuple(mic_wav.shape)} and {tuple(ref_wav.shape)}"
            )
        num_samples = mic_wav.shape[-1]

        # Zero-pad the tail so every input sample lies inside a full frame:
        # at least one frame, and (padded_len - win_len) divisible by stride.
        # -(-a // b) is ceil(a / b) for integers (also for negative a).
        n_frames = max(1, -(-(num_samples - self.win_len) // self.stride) + 1)
        padded_len = (n_frames - 1) * self.stride + self.win_len
        pad = padded_len - num_samples
        if pad > 0:
            mic_wav = F.pad(mic_wav, (0, pad))
            ref_wav = F.pad(ref_wav, (0, pad))

        # Framing: [B, padded_len] -> [B, n_frames, win_len]; frame i covers
        # samples [i * stride, i * stride + win_len).
        mic_frames = mic_wav.unfold(-1, self.win_len, self.stride)
        ref_frames = ref_wav.unfold(-1, self.win_len, self.stride)

        mic_feature = self.mix_encoder(mic_frames)
        ref_feature = self.ref_encoder(ref_frames)
        # 2 x [B, N, d_model] -> [B, N, 2 * d_model]
        concat_feature = torch.cat([mic_feature, ref_feature], dim=-1)

        mask = self.input(concat_feature)
        for layer in self.layers:
            mask = layer(mask)
        # NOTE(review): the mask is not squashed (no sigmoid), so it is an
        # unbounded multiplicative gain on the mic features - confirm intended.
        mask_feature = mic_feature * mask

        out_frames = self.output(self.decoder(mask_feature))  # [B, N, win_len]

        # Overlap-add via fold, with time along the height axis (width 1).
        # No synthesis window or overlap normalization is applied: when
        # stride < win_len, interior samples sum contributions of several frames.
        predicted_wav = F.fold(
            out_frames.transpose(1, 2),  # [B, win_len, N]
            output_size=(padded_len, 1),
            kernel_size=(self.win_len, 1),
            stride=(self.stride, 1),
        )  # [B, 1, padded_len, 1]
        predicted_wav = predicted_wav.squeeze(-1).squeeze(1)
        # Drop the padded tail so the output aligns with the input.
        return predicted_wav[..., :num_samples]