In [1]:
import os
import torch
import torchaudio
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
import numpy as np
import librosa
from pystoi.stoi import stoi

In [2]:
# Pick the compute device once: CUDA when PyTorch reports it as usable,
# otherwise the CPU. The bare name on the last line displays the choice.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

In [3]:
class AudioDataset(torch.utils.data.Dataset):
    """Paired noisy/clean waveform dataset read from two directories.

    Files are paired by position after sorting each directory listing, so the
    two directories must contain the same number of files with names that sort
    into matching order.

    Args:
        noisy_dir: Directory containing the noisy recordings.
        clean_dir: Directory containing the clean reference recordings.
        sample_rate: Target sample rate in Hz. Files stored at another rate
            are resampled to it; pass ``None`` to keep each file's own rate.
        max_length: Number of samples every item is cropped or zero-padded
            to. ``None`` keeps the natural length of each file.

    Raises:
        ValueError: If the two directories hold a different number of files.
    """

    def __init__(self, noisy_dir, clean_dir, sample_rate=16000, max_length=50000):
        self.noisy_dir = noisy_dir
        self.clean_dir = clean_dir
        self.sample_rate = sample_rate
        self.noisy_files = sorted(os.listdir(noisy_dir))
        self.clean_files = sorted(os.listdir(clean_dir))
        self.max_length = max_length

        # Pairing is purely positional, so a count mismatch means at least some
        # noisy clips would be matched with the wrong clean reference.
        if len(self.noisy_files) != len(self.clean_files):
            raise ValueError(
                f"noisy_dir has {len(self.noisy_files)} files but clean_dir has "
                f"{len(self.clean_files)}; the directories must pair one-to-one"
            )

    def __len__(self):
        """Return the number of noisy/clean pairs."""
        return len(self.noisy_files)

    def __getitem__(self, idx):
        """Return the ``(noisy, clean)`` pair at ``idx`` as two 1-D tensors."""
        noisy_path = os.path.join(self.noisy_dir, self.noisy_files[idx])
        clean_path = os.path.join(self.clean_dir, self.clean_files[idx])

        noisy_waveform = self._load(noisy_path)
        clean_waveform = self._load(clean_path)

        if self.max_length is not None:
            noisy_waveform = self._fix_length(noisy_waveform, self.max_length)
            clean_waveform = self._fix_length(clean_waveform, self.max_length)

        return noisy_waveform.squeeze(0), clean_waveform.squeeze(0)

    def _load(self, path):
        """Load one file as a mono ``(1, time)`` tensor at ``self.sample_rate``."""
        waveform, file_rate = torchaudio.load(path)
        waveform = self._to_mono(waveform)
        if self.sample_rate is not None and file_rate != self.sample_rate:
            waveform = torchaudio.functional.resample(waveform, file_rate, self.sample_rate)
        return waveform

    @staticmethod
    def _to_mono(waveform):
        """Average the channel axis so ``squeeze(0)`` always yields a 1-D signal."""
        if waveform.dim() > 1 and waveform.shape[0] > 1:
            return waveform.mean(dim=0, keepdim=True)
        return waveform

    def _fix_length(self, waveform, max_length):
        """Crop or right-pad with zeros along the last axis to ``max_length``."""
        length = waveform.shape[-1]
        if length > max_length:
            return waveform[..., :max_length]
        if length < max_length:
            return torch.nn.functional.pad(waveform, (0, max_length - length))
        return waveform

In [4]:
# Dataset locations. The defaults are the original machine-specific paths; set
# the DENOISE_NOISY_DIR / DENOISE_CLEAN_DIR environment variables to point the
# notebook at another machine's data without editing this cell.
noisy_files_path = os.environ.get("DENOISE_NOISY_DIR", "C:\\Users\\Ksenia\\Desktop\\train_data\\train_combined")
clean_files_path = os.environ.get("DENOISE_CLEAN_DIR", "C:\\Users\\Ksenia\\Desktop\\train")

In [5]:
# Training data: each noisy clip is paired with its clean reference and served
# in shuffled batches of 8. No validation dataset or loader is built here.
train_dataset = AudioDataset(noisy_dir=noisy_files_path, clean_dir=clean_files_path)
train_loader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True)

In [6]:
class SincConv(nn.Module):
    """Bank of learnable windowed-sinc band-pass filters applied as a 1-D convolution.

    Each output channel is described by two learnable cutoff frequencies in Hz
    (``band_pass[:, 0]`` = low, ``band_pass[:, 1]`` = high). The filter taps are
    rebuilt from those cutoffs with torch operations on every forward pass, so
    gradients reach ``band_pass``.

    Args:
        out_channels: Number of band-pass filters.
        kernel_size: Number of taps per filter. An odd value keeps the output
            the same length as the input, given the ``kernel_size // 2`` padding.
        sample_rate: Sample rate in Hz used to convert the cutoffs into
            cycles per sample.

    Note:
        The cutoffs are not constrained; nothing here enforces ``low < high``
        or keeps them below the Nyquist frequency during training.
    """

    def __init__(self, out_channels, kernel_size, sample_rate):
        super(SincConv, self).__init__()
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.sample_rate = sample_rate
        self.band_pass = nn.Parameter(torch.empty(out_channels, 2))
        self.init_kernels()

    def init_kernels(self):
        """Spread the initial low cutoffs over 30-300 Hz and the high ones over 3-8 kHz."""
        with torch.no_grad():
            self.band_pass[:, 0] = torch.linspace(30, 300, self.out_channels)
            self.band_pass[:, 1] = torch.linspace(3000, 8000, self.out_channels)

    def forward(self, x):
        """Filter ``x`` of shape ``(batch, 1, time)`` into ``(batch, out_channels, time')``."""
        filters = self.create_filters().to(device=x.device, dtype=x.dtype)
        return nn.functional.conv1d(x, filters, stride=1, padding=self.kernel_size // 2)

    def create_filters(self):
        """Return the current filter bank with shape ``(out_channels, 1, kernel_size)``."""
        low = self.band_pass[:, 0].unsqueeze(1)
        high = self.band_pass[:, 1].unsqueeze(1)
        return self.sinc_filter(low, high).unsqueeze(1)

    def sinc_filter(self, low, high):
        """Build Hamming-windowed band-pass taps for cutoffs ``low``/``high`` in Hz.

        ``low`` and ``high`` may be scalars (result shape ``(kernel_size,)``) or
        tensors of shape ``(channels, 1)`` (result shape ``(channels, kernel_size)``).
        The computation stays inside autograd, so the result is differentiable
        with respect to tensor cutoffs.
        """
        ref = self.band_pass
        low = torch.as_tensor(low, dtype=ref.dtype, device=ref.device)
        high = torch.as_tensor(high, dtype=ref.dtype, device=ref.device)

        # Integer-spaced tap offsets centred on the middle of the kernel.
        n = torch.arange(self.kernel_size, dtype=ref.dtype, device=ref.device)
        n = n - (self.kernel_size - 1) / 2

        # Cutoffs in cycles per sample; this is where sample_rate enters.
        low_n = low / self.sample_rate
        high_n = high / self.sample_rate

        # Band-pass = difference of two ideal low-pass responses. torch.sinc is
        # the normalised sinc and is well defined (and differentiable) at 0,
        # so the centre tap needs no special case.
        response = 2 * high_n * torch.sinc(2 * high_n * n) - 2 * low_n * torch.sinc(2 * low_n * n)

        # Symmetric Hamming window: 0.54 - 0.46 * cos(2*pi*k / (kernel_size - 1)).
        window = torch.hamming_window(
            self.kernel_size, periodic=False, dtype=ref.dtype, device=ref.device
        )
        return response * window

class BasicBlock(nn.Module):
    """Two-convolution 1-D residual block with batch norm and ReLU.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by both convolutions.
        stride: Stride of the first convolution.

    When ``stride != 1`` or the channel counts differ, the skip connection is a
    1x1 convolution plus batch norm so its output matches the main branch.
    Otherwise the skip is an identity with no parameters, which leaves the
    ``state_dict`` of same-shape blocks unchanged.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(out_channels)

        # Built last so the initialisation order of the layers above is unchanged.
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm1d(out_channels),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = out + self.shortcut(x)
        out = self.relu(out)
        return out

class DeNoise(nn.Module):
    """Waveform denoiser: sinc filter bank -> residual stack -> bidirectional GRU -> 1 channel.

    Args:
        kernel_size: Tap count handed to the ``SincConv`` front end.
        sample_rate: Sample rate handed to the ``SincConv`` front end.
        resnet_blocks: How many ``BasicBlock`` layers follow the front end.
        sinc_out_channels: Channel width used from the front end up to the
            final 1x1 convolution.
        gru_hidden_size: Hidden size of each GRU direction.
        gru_layers: Number of stacked GRU layers.
    """

    def __init__(self, kernel_size, sample_rate, resnet_blocks, sinc_out_channels=10, gru_hidden_size=128, gru_layers=1):
        super().__init__()
        width = sinc_out_channels

        self.sinc_conv = SincConv(width, kernel_size, sample_rate)

        residual_stack = []
        for _ in range(resnet_blocks):
            residual_stack.append(BasicBlock(width, width))
        self.resnet_blocks = nn.Sequential(*residual_stack)

        self.gru = nn.GRU(
            input_size=width,
            hidden_size=gru_hidden_size,
            num_layers=gru_layers,
            batch_first=True,
            bidirectional=True,
        )
        # The GRU is bidirectional, so its output carries 2 * hidden features.
        self.fc = nn.Linear(gru_hidden_size * 2, width)
        self.output_conv = nn.Conv1d(width, 1, kernel_size=1)

    def forward(self, x):
        """Run the full stack on ``x`` and return the single-channel output."""
        features = self.resnet_blocks(self.sinc_conv(x))
        # The GRU is batch_first, so move the channel axis last for it ...
        sequence, _ = self.gru(features.transpose(1, 2))
        # ... and back to channels-first for the final 1x1 convolution.
        projected = self.fc(sequence).transpose(1, 2)
        return self.output_conv(projected)

In [7]:
def compute_stoi(clean_audio, enhanced_audio, sample_rate):
    """Return the mean STOI score over a batch of clean/enhanced signals.

    Args:
        clean_audio: Tensor whose last axis is time, e.g. ``(batch, time)`` or
            ``(batch, 1, time)``.
        enhanced_audio: Tensor laid out like ``clean_audio``.
        sample_rate: Sample rate in Hz passed through to ``stoi``.

    Returns:
        The average of the per-signal scores as a Python float.

    Raises:
        ValueError: If either input is empty or the two inputs hold a
            different number of signals.
    """
    # Truncate along the last (time) axis; axis 1 is the channel axis for
    # (batch, 1, time) inputs and must not be used as the length.
    min_len = min(clean_audio.size(-1), enhanced_audio.size(-1))
    if min_len == 0 or clean_audio.numel() == 0 or enhanced_audio.numel() == 0:
        raise ValueError("compute_stoi received an empty signal")

    # One row per signal, so each one is scored on its own instead of being
    # concatenated with the rest of the batch.
    clean_np = clean_audio[..., :min_len].detach().cpu().reshape(-1, min_len).numpy()
    enhanced_np = enhanced_audio[..., :min_len].detach().cpu().reshape(-1, min_len).numpy()
    if clean_np.shape[0] != enhanced_np.shape[0]:
        raise ValueError(
            f"clean and enhanced batches differ in size: {clean_np.shape[0]} vs {enhanced_np.shape[0]}"
        )

    scores = [
        stoi(clean_row, enhanced_row, sample_rate, extended=False)
        for clean_row, enhanced_row in zip(clean_np, enhanced_np)
    ]
    return float(np.mean(scores))

In [8]:
def train(model, train_loader, criterion, optimizer, num_epochs=20, sample_rate=16000, stoi_every=31):
    """Train ``model`` on ``train_loader`` and print per-epoch loss and STOI.

    Args:
        model: Network called on batches of shape ``(batch, 1, time)``.
        train_loader: Iterable yielding ``(noisy, clean)`` batches of shape
            ``(batch, time)``.
        criterion: Loss called as ``criterion(outputs, clean)``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        num_epochs: Number of passes over ``train_loader``.
        sample_rate: Sample rate in Hz forwarded to ``compute_stoi``.
        stoi_every: Score STOI on every ``stoi_every``-th batch, starting with
            the first. A falsy value disables STOI scoring.

    The model is moved to the module-level ``device``. Nothing is returned;
    progress is reported with ``print``. An epoch that saw no batches, or no
    STOI evaluations, reports ``nan`` for that statistic.
    """
    model.to(device)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        total_batches = 0
        # Reset every epoch so the reported STOI is this epoch's average.
        stoi_total = 0.0
        stoi_evaluations = 0

        for i, (noisy, clean) in enumerate(tqdm(train_loader, desc=f'Training epoch {epoch+1}')):
            # Add the channel axis the model's convolutional front end expects.
            noisy = noisy.unsqueeze(1).to(device)
            clean = clean.unsqueeze(1).to(device)

            optimizer.zero_grad()

            outputs = model(noisy)

            loss = criterion(outputs, clean)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            total_batches += 1

            # STOI is only sampled on a subset of batches. It reuses the
            # outputs already computed above, so no mode switch is needed.
            if stoi_every and i % stoi_every == 0:
                stoi_total += compute_stoi(clean, outputs, sample_rate)
                stoi_evaluations += 1

        # Average over the batches that were actually scored.
        avg_train_stoi_score = stoi_total / stoi_evaluations if stoi_evaluations else float('nan')
        print(f"Epoch [{epoch+1}/{num_epochs}], Train STOI Score: {avg_train_stoi_score:.4f},")

        epoch_loss = running_loss / total_batches if total_batches else float('nan')
        print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {epoch_loss:.4f}')

In [9]:
# Denoiser configuration: 101-tap sinc front end at 16 kHz, three residual
# blocks over 10 channels, and a 2-layer bidirectional GRU with 128 hidden units.
model = DeNoise(
    kernel_size=101,
    sample_rate=16000,
    resnet_blocks=3,
    sinc_out_channels=10,
    gru_hidden_size=128,
    gru_layers=2,
)

# Mean-squared error against the clean waveform, optimised with Adam.
criterion = nn.MSELoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.001)

In [None]:
# Full training run: 250 epochs, all other settings left at train()'s defaults.
train(
    model=model,
    train_loader=train_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=250,
)

In [None]:
def load_audio(file_path, sr=16000):
    """Read ``file_path`` with ``librosa.load`` at rate ``sr``.

    Returns:
        The ``(audio, sample_rate)`` pair produced by ``librosa.load``.
    """
    loaded = librosa.load(file_path, sr=sr)
    audio, sample_rate = loaded
    return audio, sample_rate

In [None]:
def preprocess_audio(audio, sample_rate, target_length=16000):
    """Turn a raw signal into a model-ready float tensor on ``device``.

    Two leading axes are prepended to ``audio``, and the result is zero-padded
    on the right when its third axis is shorter than ``target_length``. Longer
    signals are returned at full length. ``sample_rate`` is accepted but not
    used.
    """
    batch = torch.tensor(audio).float()[None, None].to(device)

    shortfall = target_length - batch.size(2)
    if shortfall > 0:
        batch = torch.nn.functional.pad(batch, (0, shortfall))

    return batch

In [None]:
def denoise_audio(model, noisy_audio):
    """Run ``model`` on ``noisy_audio`` in inference mode and return a NumPy array.

    The model is switched to eval mode for the forward pass, so layers such as
    batch norm use their stored statistics and do not update them. Its previous
    train/eval mode is restored afterwards, even if the forward pass raises.

    Args:
        model: The network to run.
        noisy_audio: Input tensor passed to ``model`` as-is.

    Returns:
        The model output with all size-1 axes squeezed away, as a CPU NumPy array.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            clean_audio = model(noisy_audio)
    finally:
        model.train(was_training)
    return clean_audio.squeeze().cpu().numpy()

In [22]:
import soundfile as sf

def save_audio(file_path, audio, sample_rate):
    """Write ``audio`` to ``file_path`` at ``sample_rate`` using ``soundfile``."""
    sf.write(file=file_path, data=audio, samplerate=sample_rate)

# End-to-end inference on one file: load, pad to the model's minimum input
# length, denoise, and write the result next to the notebook.
noisy_audio_path = 'C:\\Users\\Ksenia\\Desktop\\test.wav'
output_path = 'denoised_output.wav'

audio, sr = load_audio(noisy_audio_path)
preprocessed_audio = preprocess_audio(audio, sr)
clean_audio = denoise_audio(model, preprocessed_audio)
# preprocess_audio zero-pads clips shorter than its target length; cut the
# output back so the saved file is no longer than the input.
clean_audio = clean_audio[:len(audio)]
save_audio(output_path, clean_audio, sr)