In [2]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchaudio
import numpy as np
from sklearn.model_selection import train_test_split


In [4]:
class ValentiniDataset(Dataset):
    """Paired noisy/clean log-magnitude spectrogram patches cut from audio files.

    Every file is split into overlapping fixed-length segments. Each item is the
    pair ``(noisy, clean)`` of ``log(1 + |STFT|)`` features for one segment, each
    shaped ``(1, n_fft // 2 + 1, frames)``.

    Args:
        noisy_files: Paths of the noisy recordings.
        clean_files: Paths of the clean recordings; ``clean_files[i]`` must be the
            reference for ``noisy_files[i]``.
        segment_length: Segment size in samples.
        hop_length: Distance in samples between the starts of consecutive segments.
        n_fft: FFT size used for the STFT.
        hop_stft: STFT hop size in samples.
        num_frames: Number of leading STFT frames kept per segment; later frames
            are discarded. ``None`` keeps every frame.

    Raises:
        ValueError: If the two file lists differ in length, or if
            ``segment_length`` / ``hop_length`` is not positive.
    """

    def __init__(self, noisy_files, clean_files, segment_length=48000, hop_length=24000,
                 n_fft=512, hop_stft=128, num_frames=123):
        if len(noisy_files) != len(clean_files):
            raise ValueError(
                f"noisy_files and clean_files must pair up one-to-one, "
                f"got {len(noisy_files)} and {len(clean_files)} files"
            )
        if segment_length <= 0 or hop_length <= 0:
            raise ValueError("segment_length and hop_length must be positive")

        self.noisy_files = noisy_files
        self.clean_files = clean_files
        self.segment_length = segment_length
        self.hop_length = hop_length
        self.n_fft = n_fft
        self.hop_stft = hop_stft
        self.num_frames = num_frames
        # Built once here instead of twice per item in __getitem__.
        self.window = torch.hann_window(n_fft)
        # One (file index, start sample, end sample) entry per segment.
        self.segments = []

        for idx, path in enumerate(noisy_files):
            audio_length = self._count_frames(path)
            # Files shorter than one segment still contribute a single
            # (zero-padded) segment.
            num_segments = max(1, (audio_length - segment_length) // hop_length + 1)
            for seg_idx in range(num_segments):
                start = seg_idx * hop_length
                self.segments.append((idx, start, start + segment_length))

    @staticmethod
    def _count_frames(path):
        """Return the number of samples per channel stored in ``path``."""
        # Reading the header avoids decoding every file just to learn its length.
        frames = torchaudio.info(path).num_frames
        if frames > 0:
            return frames
        # The metadata carried no usable frame count; decode the file instead.
        waveform, _ = torchaudio.load(path)
        return waveform.size(1)

    def _load_segment(self, path, start):
        """Load ``segment_length`` samples from ``path`` starting at ``start``.

        Only the first channel is used, and the result is zero-padded on the
        right when the file ends before the segment does.
        """
        waveform, _ = torchaudio.load(path, frame_offset=start, num_frames=self.segment_length)
        segment = waveform[0, :self.segment_length]
        missing = self.segment_length - segment.size(0)
        if missing > 0:
            segment = torch.nn.functional.pad(segment, (0, missing))
        return segment

    def _log_magnitude(self, segment):
        """Convert a 1-D waveform to a ``(1, freq, frames)`` log-magnitude patch."""
        spec = torch.stft(segment, n_fft=self.n_fft, hop_length=self.hop_stft,
                          window=self.window, return_complex=True)
        # log(1 + |X|) compresses the dynamic range and maps silence to 0.
        logmag = torch.log(1 + torch.abs(spec)).unsqueeze(0)
        # Slicing with None keeps all frames.
        return logmag[:, :, :self.num_frames]

    def __len__(self):
        return len(self.segments)

    def __getitem__(self, index):
        idx, start, _ = self.segments[index]
        noisy_segment = self._load_segment(self.noisy_files[idx], start)
        clean_segment = self._load_segment(self.clean_files[idx], start)
        return self._log_magnitude(noisy_segment), self._log_magnitude(clean_segment)

In [5]:
class UNet(nn.Module):
    """Three-level U-Net mapping a 1-channel spectrogram to a 1-channel spectrogram.

    The encoder halves the spatial size three times (64 -> 128 -> 256 channels),
    a 512-channel bottleneck follows, and the decoder upsamples back while
    concatenating the matching encoder features (skip connections).

    The forward pass accepts any input of shape ``(batch, 1, H, W)`` with
    ``H >= 8`` and ``W >= 8`` and returns a tensor of the same shape. Layer names
    and shapes are unchanged, so existing checkpoints still load.
    """

    def __init__(self):
        super(UNet, self).__init__()

        self.enc1 = self.conv_block(1, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = self.conv_block(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = self.conv_block(128, 256)
        self.pool3 = nn.MaxPool2d(2)

        self.bottleneck = self.conv_block(256, 512)

        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = self.conv_block(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = self.conv_block(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2, output_padding=(0, 1))
        self.dec1 = self.conv_block(128, 64)
        self.out = nn.Conv2d(64, 1, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        """Return two 3x3 conv + batch-norm + ReLU stages that keep the spatial size."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    @staticmethod
    def _crop_like(skip, ref):
        """Crop ``skip`` (from the top-left corner) to the height/width of ``ref``."""
        return skip[:, :, :ref.size(2), :ref.size(3)]

    @staticmethod
    def _fit_like(up, ref):
        """Crop and/or zero-pad ``up`` at the bottom/right to the size of ``ref``."""
        up = up[:, :, :ref.size(2), :ref.size(3)]
        pad_h = ref.size(2) - up.size(2)
        pad_w = ref.size(3) - up.size(3)
        if pad_h or pad_w:
            up = nn.functional.pad(up, (0, pad_w, 0, pad_h), mode='constant', value=0)
        return up

    def forward(self, x):
        """Run the network on ``x`` of shape ``(batch, 1, H, W)``; output has the same shape."""
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))
        b = self.bottleneck(self.pool3(e3))

        # Max-pooling floors odd sizes, so an upsampled map can be one pixel
        # smaller than its skip connection; crop the skip to match.
        u3 = self.up3(b)
        d3 = self.dec3(torch.cat([u3, self._crop_like(e3, u3)], dim=1))

        u2 = self.up2(d3)
        d2 = self.dec2(torch.cat([u2, self._crop_like(e2, u2)], dim=1))

        # At full resolution the output must match the input exactly, so the
        # upsampled map is fitted to e1 instead. For a 257x123 input this is the
        # same (0, 2, 0, 1) zero-padding that used to be hard-coded.
        u1 = self._fit_like(self.up1(d2), e1)
        d1 = self.dec1(torch.cat([u1, e1], dim=1))

        return self.out(d1)

In [6]:
from tqdm.auto import tqdm

def _run_epoch(model, loader, criterion, device, desc, optimizer=None):
    """Run one pass over ``loader`` and return the mean per-sample loss.

    The model is updated only when ``optimizer`` is given; otherwise the pass
    runs in eval mode with gradients disabled.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0

    bar = tqdm(loader, desc=desc, leave=False)
    with torch.set_grad_enabled(training):
        for noisy_batch, clean_batch in bar:
            noisy_batch, clean_batch = noisy_batch.to(device), clean_batch.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(noisy_batch)
            loss = criterion(outputs, clean_batch)
            if training:
                loss.backward()
                optimizer.step()
            batch_loss = loss.item()
            # Weight by batch size so a smaller final batch does not skew the mean.
            total_loss += batch_loss * noisy_batch.size(0)
            bar.set_postfix({'loss': batch_loss})
    return total_loss / len(loader.dataset)


def train_model(model, train_loader, val_loader, num_epochs, device,
                lr=1e-3, checkpoint_path='/kaggle/working/best_model.pth'):
    """Train ``model`` with MSE loss and checkpoint the best validation epoch.

    Uses Adam plus ``ReduceLROnPlateau`` (factor 0.1, patience 5) driven by the
    validation loss. Whenever the validation loss improves, the model's
    ``state_dict`` is written to ``checkpoint_path``.

    Args:
        model: Network to train; expected to already live on ``device``.
        train_loader: DataLoader yielding ``(noisy, clean)`` batches for training.
        val_loader: DataLoader yielding ``(noisy, clean)`` batches for validation.
        num_epochs: Number of epochs to run.
        device: Device the batches are moved to.
        lr: Initial Adam learning rate.
        checkpoint_path: Where the best ``state_dict`` is saved.

    Returns:
        dict with keys ``'train_loss'`` and ``'val_loss'``, each a list holding
        one mean per-sample loss per epoch.

    Raises:
        ValueError: If either loader's dataset is empty.
    """
    if len(train_loader.dataset) == 0 or len(val_loader.dataset) == 0:
        raise ValueError("train_loader and val_loader must both contain at least one sample")

    # Create the target directory up front so the first save cannot fail late.
    checkpoint_dir = os.path.dirname(checkpoint_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

    best_val_loss = float('inf')
    history = {'train_loss': [], 'val_loss': []}

    for epoch in tqdm(range(num_epochs), desc="Epochs"):
        train_loss = _run_epoch(
            model, train_loader, criterion, device,
            desc=f"Training Epoch {epoch+1}/{num_epochs}", optimizer=optimizer,
        )
        val_loss = _run_epoch(
            model, val_loader, criterion, device,
            desc=f"Validation Epoch {epoch+1}/{num_epochs}",
        )
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)

        scheduler.step(val_loss)

        tqdm.write(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), checkpoint_path)

    return history

In [7]:
def denoise_audio(model, noisy_file, n_fft=512, hop_stft=128, device='cuda'):
    model.eval()
    noisy_audio, sr = torchaudio.load(noisy_file)
    noisy_stft = torch.stft(noisy_audio.squeeze(0), n_fft=n_fft, hop_length=hop_stft, 
                            window=torch.hann_window(n_fft), return_complex=True)
    noisy_mag = torch.abs(noisy_stft)
    noisy_phase = torch.angle(noisy_stft)
    noisy_logmag = torch.log(1 + noisy_mag).unsqueeze(0).unsqueeze(0).to(device)

    with torch.no_grad():
        pred_logmag = model(noisy_logmag)
    pred_mag = torch.exp(pred_logmag.squeeze(0).squeeze(0)) - 1
    pred_stft = pred_mag * torch.exp(1j * noisy_phase)
    denoised_audio = torch.istft(pred_stft, n_fft=n_fft, hop_length=hop_stft, 
                                 window=torch.hann_window(n_fft))
    return denoised_audio

In [8]:
import torchaudio

noisy_dir = '/kaggle/input/valentini-noisy/noisy_trainset_28spk_wav'
clean_dir = '/kaggle/input/valentini-noisy/clean_trainset_28spk_wav'

print("Collecting noisy file paths...")
noisy_files = sorted(os.path.join(noisy_dir, f) for f in os.listdir(noisy_dir) if f.endswith('.wav'))
print("Collecting clean file paths...")
clean_files = sorted(os.path.join(clean_dir, f) for f in os.listdir(clean_dir) if f.endswith('.wav'))

# The two lists are sorted independently, so noisy_files[i] and clean_files[i]
# only belong together if both directories hold the same file names. Check this
# before training on silently misaligned pairs.
if not noisy_files:
    raise FileNotFoundError(f"No .wav files found in {noisy_dir}")
mismatched_pairs = [
    (os.path.basename(noisy_path), os.path.basename(clean_path))
    for noisy_path, clean_path in zip(noisy_files, clean_files)
    if os.path.basename(noisy_path) != os.path.basename(clean_path)
]
if len(noisy_files) != len(clean_files) or mismatched_pairs:
    raise ValueError(
        f"Noisy/clean files do not pair up: {len(noisy_files)} noisy vs "
        f"{len(clean_files)} clean, first mismatches: {mismatched_pairs[:3]}"
    )

sample_rate = torchaudio.info(noisy_files[0]).sample_rate
print(f"Dataset sample rate: {sample_rate} Hz")
assert sample_rate == 48000, "Expected 48kHz for Valentini dataset!"

# Split at file level so segments of one recording never land in both sets.
noisy_train, noisy_val, clean_train, clean_val = train_test_split(
    noisy_files, clean_files, test_size=0.2, random_state=42
)

# One-second segments (48000 samples) with 50% overlap.
train_dataset = ValentiniDataset(noisy_train, clean_train, segment_length=48000, hop_length=24000)
val_dataset = ValentiniDataset(noisy_val, clean_val, segment_length=48000, hop_length=24000)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=2)

print(f"Training samples: {len(train_dataset)}, Validation samples: {len(val_dataset)}")

Collecting noisy file paths...
Collecting clean file paths...
Dataset sample rate: 48000 Hz
Training samples: 40000, Validation samples: 10256


In [None]:
# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Build a fresh network on the chosen device and train it for 50 epochs.
model = UNet()
model = model.to(device)

train_model(model, train_loader, val_loader, num_epochs=50, device=device)

Using device: cuda


Epochs:   0%|          | 0/50 [00:00<?, ?it/s]

Training Epoch 1/50:   0%|          | 0/2500 [00:00<?, ?it/s]

In [None]:
import random

import ipywidgets as widgets
from IPython.display import Audio, display

# Pick one held-out recording at random for the listening test.
# `random` was used here without ever being imported.
random_idx = random.randrange(len(noisy_val))
noisy_file = noisy_val[random_idx]
print(f"Selected file: {noisy_file}")

def denoise_audio_full(model, noisy_file, n_fft=512, hop_stft=128, device=device):
    """Denoise a whole audio file and return it together with its sample rate.

    The noisy STFT magnitude is replaced by the model's estimate while the noisy
    phase is reused. The STFT, phase and inverse STFT stay on the CPU; only the
    model input is moved to ``device``.

    Args:
        model: Network mapping a ``(1, 1, freq, frames)`` log-magnitude
            spectrogram to one of the same shape; expected to live on ``device``.
        noisy_file: Path of the recording to enhance.
        n_fft: FFT size used for the STFT.
        hop_stft: STFT hop size in samples.
        device: Device the model input is moved to.

    Returns:
        Tuple ``(denoised_audio, sr)``: a 1-D CPU tensor with as many samples as
        the input file, and the file's sample rate. Multi-channel files are
        averaged to mono first.
    """
    model.eval()
    noisy_audio, sr = torchaudio.load(noisy_file)
    # Averaging over channels leaves mono audio untouched.
    waveform = noisy_audio.mean(dim=0)

    window = torch.hann_window(n_fft)
    noisy_stft = torch.stft(waveform, n_fft=n_fft, hop_length=hop_stft,
                            window=window, return_complex=True)
    noisy_phase = torch.angle(noisy_stft)
    noisy_logmag = torch.log(1 + torch.abs(noisy_stft))

    with torch.no_grad():
        pred_logmag = model(noisy_logmag.unsqueeze(0).unsqueeze(0).to(device))
    # Bring the estimate back to the CPU so it can be combined with the phase.
    pred_logmag = pred_logmag.squeeze(0).squeeze(0).cpu().to(noisy_logmag.dtype)
    if pred_logmag.shape != noisy_logmag.shape:
        raise ValueError(
            f"model returned shape {tuple(pred_logmag.shape)}, "
            f"expected {tuple(noisy_logmag.shape)}"
        )

    # log(1 + |X|) is never negative; clamping keeps exp(x) - 1 from producing a
    # negative magnitude, which would flip the phase.
    pred_mag = torch.exp(pred_logmag.clamp(min=0)) - 1
    pred_stft = pred_mag * torch.exp(1j * noisy_phase)
    denoised_audio = torch.istft(pred_stft, n_fft=n_fft, hop_length=hop_stft,
                                 window=window, length=noisy_audio.size(1))
    return denoised_audio, sr

model = UNet().to(device)
# map_location lets a checkpoint saved on the GPU load on a CPU-only session.
model.load_state_dict(torch.load('/kaggle/working/best_model.pth', map_location=device))
model.eval()

denoised_audio, sr = denoise_audio_full(model, noisy_file)
# Saving and playback both need a CPU tensor.
denoised_audio = denoised_audio.detach().cpu()
torchaudio.save('/kaggle/working/denoised_test.wav', denoised_audio.unsqueeze(0), sr)

noisy_audio, _ = torchaudio.load(noisy_file)

# Only `ipywidgets as widgets` is imported, so the widget classes are
# referenced through that namespace.
play_noisy_btn = widgets.Button(description="Play Noisy")
play_denoised_btn = widgets.Button(description="Play Denoised")
volume_slider = widgets.FloatSlider(value=1.0, min=0.0, max=1.0, step=0.1, description="Volume")
# Output produced inside a button callback is only rendered when it is
# captured by an Output widget.
player_output = widgets.Output()


def _show_player(waveform):
    """Render an audio player for ``waveform`` scaled by the current slider value."""
    # Audio() has no volume argument: scale the samples instead and keep
    # normalize=False so the scaling is not undone. normalize=False requires
    # samples within [-1, 1], hence the clip.
    samples = (waveform.numpy() * volume_slider.value).clip(-1.0, 1.0)
    with player_output:
        player_output.clear_output(wait=True)
        display(Audio(samples, rate=sr, normalize=False))


def play_noisy(b):
    """Button callback: show a player for the original noisy recording."""
    _show_player(noisy_audio)


def play_denoised(b):
    """Button callback: show a player for the denoised recording."""
    _show_player(denoised_audio)


play_noisy_btn.on_click(play_noisy)
play_denoised_btn.on_click(play_denoised)

ui = widgets.VBox([
    widgets.HBox([play_noisy_btn, play_denoised_btn]),
    volume_slider,
    player_output
])

display(ui)