In [16]:
import os
import warnings

import librosa
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

In [17]:
# Root folder that holds the audio data. Set the NOISE_REDUCTION_DATA_DIR
# environment variable to point the notebook at another location; when it is
# unset the original local path is used, so existing runs are unaffected.
DATA_DIR = os.environ.get(
    "NOISE_REDUCTION_DATA_DIR",
    r"C:\Users\sansk\Downloads\ai_noise_reduction\data",
)
# Sub-folders with the noisy inputs and the clean reference recordings.
NOISY_DIR = os.path.join(DATA_DIR, "noisy_dataset_wav")
CLEAN_DIR = os.path.join(DATA_DIR, "clean_testset_wav")

In [18]:
class NoiseReductionDataset(Dataset):
    """Paired noisy/clean waveform dataset with fixed-length items.

    The ``.wav`` files of both directories are listed and sorted by name, and
    item ``idx`` pairs the ``idx``-th noisy file with the ``idx``-th clean
    file. Pairing is purely positional, so both directories are expected to
    hold corresponding recordings under names that sort identically.

    Every waveform is brought to exactly ``target_length`` samples: longer
    clips keep only their first ``target_length`` samples, shorter clips are
    zero-padded at the end.

    Args:
        noisy_dir: Directory containing the noisy ``.wav`` files.
        clean_dir: Directory containing the clean ``.wav`` files.
        sr: Sampling rate passed to ``librosa.load``.
        target_length: Number of samples in every returned waveform.
    """

    def __init__(self, noisy_dir, clean_dir, sr=16000, target_length=50000):
        self.noisy_dir = noisy_dir
        self.clean_dir = clean_dir
        self.sr = sr
        self.target_length = target_length
        self.noisy_files = self._list_wavs(noisy_dir)
        self.clean_files = self._list_wavs(clean_dir)
        print(f"Found {len(self.noisy_files)} noisy files and {len(self.clean_files)} clean files.")

        # __len__ silently drops the surplus files of the longer listing, and
        # positional pairing is then likely misaligned -- make that visible.
        if len(self.noisy_files) != len(self.clean_files):
            warnings.warn(
                f"Noisy/clean file counts differ ({len(self.noisy_files)} vs "
                f"{len(self.clean_files)}); only the first {len(self)} sorted "
                "pairs are used and they may not be matching recordings.",
                stacklevel=2,
            )

    @staticmethod
    def _list_wavs(directory):
        """Return the sorted names of the ``.wav`` files directly in ``directory``."""
        return sorted(name for name in os.listdir(directory) if name.endswith('.wav'))

    def __len__(self):
        """Number of (noisy, clean) pairs: the size of the shorter listing."""
        return min(len(self.noisy_files), len(self.clean_files))

    def __getitem__(self, idx):
        """Load pair ``idx`` and return ``(noisy, clean)`` fixed-length tensors."""
        noisy_path = os.path.join(self.noisy_dir, self.noisy_files[idx])
        clean_path = os.path.join(self.clean_dir, self.clean_files[idx])

        noisy = self._fit_length(self.load_audio(noisy_path))
        clean = self._fit_length(self.load_audio(clean_path))

        return noisy, clean

    def _fit_length(self, audio):
        """Crop or zero-pad ``audio`` along its last axis to ``target_length``."""
        wave = torch.as_tensor(audio)
        n_missing = self.target_length - wave.shape[-1]
        if n_missing < 0:
            # Too long: keep the leading samples. clone() so the result does
            # not keep the full-length source buffer alive.
            return wave[..., : self.target_length].clone()
        # Too short (or exact): append zeros at the end.
        return F.pad(wave, (0, n_missing), 'constant', 0)

    def load_audio(self, filepath):
        """Load ``filepath`` with ``librosa.load`` at ``self.sr`` and return the samples."""
        audio, _ = librosa.load(filepath, sr=self.sr)
        return audio

In [19]:
# Build the paired dataset; every waveform is fixed to 50000 samples.
dataset = NoiseReductionDataset(
    noisy_dir=NOISY_DIR,
    clean_dir=CLEAN_DIR,
    target_length=50000,
)

Found 824 noisy files and 824 clean files.


In [20]:
# Wrap the dataset in a loader that yields shuffled mini-batches of 4 pairs.
dataloader = DataLoader(
    dataset=dataset,
    batch_size=4,
    shuffle=True,
)

In [21]:
# Sanity check: pull a single batch and report the tensor shapes.
first_batch = next(iter(dataloader), None)
if first_batch is not None:  # an empty loader prints nothing
    noisy_batch, clean_batch = first_batch
    print("Noisy batch shape:", noisy_batch.shape)
    print("Clean batch shape:", clean_batch.shape)

Noisy batch shape: torch.Size([4, 50000])
Clean batch shape: torch.Size([4, 50000])
