# In [None]:
import torch
import torch.nn as nn
import numpy as np

# In [None]:
class SincConv(nn.Module):
    """Learnable band-pass filter bank built from Hamming-windowed sinc functions.

    Each of the ``out_channels`` filters is parameterised by a (low, high)
    cut-off pair in Hz stored in ``band_pass``. The filter taps are rebuilt from
    those cut-offs with torch ops on every forward pass, so the cut-offs receive
    gradients and are trained by back-propagation.

    Args:
        out_channels: Number of band-pass filters (output channels).
        kernel_size: Number of taps per filter.
        sample_rate: Sampling rate in Hz; cut-offs are divided by it to obtain
            normalised frequencies (cycles per sample).
    """

    def __init__(self, out_channels, kernel_size, sample_rate):
        super(SincConv, self).__init__()
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.sample_rate = sample_rate
        # Column 0 holds the low cut-off, column 1 the high cut-off (Hz).
        self.band_pass = nn.Parameter(torch.empty(out_channels, 2))
        self.init_kernels()

    def init_kernels(self):
        """Spread low cut-offs linearly over 30-300 Hz and high cut-offs over 3000-8000 Hz."""
        with torch.no_grad():
            self.band_pass[:, 0] = torch.linspace(30, 300, self.out_channels)
            self.band_pass[:, 1] = torch.linspace(3000, 8000, self.out_channels)

    def forward(self, x):
        """Filter ``x`` with the current filter bank.

        Args:
            x: Tensor of shape (batch, 1, time); the filters have a single
                input channel.

        Returns:
            Tensor of shape (batch, out_channels, time'), where time' equals
            time for odd ``kernel_size`` and time + 1 for even ``kernel_size``
            (padding is ``kernel_size // 2`` on both sides).
        """
        filters = self.create_filters()
        return nn.functional.conv1d(x, filters, stride=1, padding=self.kernel_size // 2)

    def create_filters(self):
        """Return the filter bank as an (out_channels, 1, kernel_size) tensor.

        The filters live on ``band_pass``'s device/dtype and stay attached to
        the autograd graph.
        """
        low = self.band_pass[:, 0:1]
        high = self.band_pass[:, 1:2]
        return self._band_pass_taps(low, high).unsqueeze(1)

    def sinc_filter(self, low, high):
        """Return the windowed band-pass taps for one (low, high) cut-off pair.

        Args:
            low: Scalar tensor, low cut-off in Hz.
            high: Scalar tensor, high cut-off in Hz.

        Returns:
            1-D tensor of length ``kernel_size``, differentiable with respect
            to ``low`` and ``high``.
        """
        return self._band_pass_taps(low.reshape(1, 1), high.reshape(1, 1))[0]

    def _band_pass_taps(self, low, high):
        """Compute windowed sinc band-pass taps for (C, 1) cut-off tensors in Hz; returns (C, kernel_size)."""
        k = self.kernel_size
        ref = self.band_pass
        # Symmetric tap positions one sample apart: integers centred on 0 for odd k,
        # half-integers for even k.
        n = torch.arange(k, dtype=ref.dtype, device=ref.device) - (k - 1) / 2
        # Cut-offs expressed in cycles per sample.
        f_low = low / self.sample_rate
        f_high = high / self.sample_rate
        is_center = n == 0
        # Replace 0 by 1 in the denominator so neither the value nor the gradient of
        # the discarded branch of torch.where becomes inf/NaN.
        safe_n = torch.where(is_center, torch.ones_like(n), n)
        taps = (torch.sin(2 * np.pi * f_high * safe_n) - torch.sin(2 * np.pi * f_low * safe_n)) / (np.pi * safe_n)
        # Limit of the expression above as n -> 0.
        taps = torch.where(is_center, 2 * (f_high - f_low), taps)
        # Symmetric Hamming window: 0.54 - 0.46 * cos(2*pi*i / (k - 1)).
        window = torch.hamming_window(k, periodic=False, dtype=ref.dtype, device=ref.device)
        return taps * window

class BasicBlock(nn.Module):
    """Two-convolution 1-D residual block: conv-BN-ReLU-conv-BN, add shortcut, ReLU.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by both convolutions.
        stride: Stride of the first convolution.

    When ``stride != 1`` or ``in_channels != out_channels`` the input cannot be
    added to the main path as-is, so a 1x1 conv + BatchNorm projection
    (``downsample``) reshapes the shortcut. Otherwise the shortcut is the
    identity and ``downsample`` is ``None``, which adds no parameters.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(out_channels)
        if stride != 1 or in_channels != out_channels:
            # A 1x1 conv with the same stride yields the same output length as conv1
            # (k=3, padding=1), so the shapes line up for the addition.
            self.downsample = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(out_channels),
            )
        else:
            self.downsample = None

    def forward(self, x):
        """Apply the block; output has ``out_channels`` channels and the stride-reduced length."""
        residual = x if self.downsample is None else self.downsample(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += residual
        out = self.relu(out)
        return out

class DeNoise(nn.Module):
    """Denoising network: SincConv front-end -> residual stack -> bidirectional GRU -> 1 channel.

    Args:
        kernel_size: Tap count of the SincConv filters.
        sample_rate: Sampling rate in Hz forwarded to SincConv.
        resnet_blocks: Number of BasicBlock layers stacked after SincConv.
        sinc_out_channels: Feature channels used throughout the network.
        gru_hidden_size: Hidden size of each GRU direction.
        gru_layers: Number of stacked GRU layers.

    Input is expected as (batch, 1, time), since the SincConv filters have a
    single input channel. Output is (batch, 1, time').
    """

    def __init__(self, kernel_size, sample_rate, resnet_blocks, sinc_out_channels=20, gru_hidden_size=128, gru_layers=2):
        super().__init__()
        channels = sinc_out_channels
        self.sinc_conv = SincConv(channels, kernel_size, sample_rate)
        stack = [BasicBlock(channels, channels) for _ in range(resnet_blocks)]
        self.resnet_blocks = nn.Sequential(*stack)
        self.gru = nn.GRU(
            input_size=channels,
            hidden_size=gru_hidden_size,
            num_layers=gru_layers,
            batch_first=True,
            bidirectional=True,
        )
        # Both GRU directions are concatenated, giving 2 * hidden features per step.
        self.fc = nn.Linear(2 * gru_hidden_size, channels)
        self.output_conv = nn.Conv1d(channels, 1, kernel_size=1)

    def forward(self, x):
        """Map a (batch, 1, time) waveform to a single-channel output of the same layout."""
        features = self.resnet_blocks(self.sinc_conv(x))
        # The GRU is batch_first, so swap to (batch, time, channels) and back afterwards.
        sequence, _hidden = self.gru(features.permute(0, 2, 1))
        projected = self.fc(sequence).permute(0, 2, 1)
        return self.output_conv(projected)