# Imports

In [1]:
import os
import pytorch_lightning as pl
import torch
import torchaudio
import numpy as np
from tqdm import tqdm
import subprocess
from urllib import request
import zipfile
from torch.utils.data import Dataset, IterableDataset, DataLoader
from dataclasses import dataclass
from progress_bar import ProgressBar 

# Definitions

In [2]:
# Root folder for the downloaded archives and extracted audio; used as the
# default `data_dir` argument by the dataset and data module classes in this file.
data_dir = 'data'

# Data Classes

## Data Index

In [3]:
@dataclass
class DataIndex:
    """Describes one downloadable archive and the local folder derived from it."""

    url: str  # download location of the archive
    id: str  # short label for the archive's content

    @property
    def dir(self):
        """Folder name for this archive: the URL's last path segment up to its first dot."""
        archive_name = self.url.rsplit('/', 1)[-1]
        return archive_name.partition('.')[0]

    def list_files(self, data_dir: str):
        """Return the sorted paths of all `.wav` files in `<data_dir>/<self.dir>`."""
        folder = os.path.join(data_dir, self.dir)
        wav_paths = []
        for entry in os.listdir(folder):
            if entry.endswith('.wav'):
                wav_paths.append(os.path.join(folder, entry))
        wav_paths.sort()
        return wav_paths

In [4]:
@dataclass
class XyDataIndexPair:
    """Couples an input (X) archive with its target (y) archive for one stage."""

    X_data_index: DataIndex  # archive holding the model inputs
    y_data_index: DataIndex  # archive holding the matching targets
    stage: str  # stage label this pair belongs to (e.g. 'train', 'val')

    def list_file_pairs(self, data_dir: str):
        """Return `(X_file, y_file)` path tuples, matched by sorted file name.

        Args:
            data_dir: Root directory that contains both extracted archive folders.

        Returns:
            A list of `(X_path, y_path)` tuples, one per audio file.

        Raises:
            AssertionError: If the two folders hold a different number of `.wav`
                files, or if a pair does not share the same base file name.
        """
        X_files = self.X_data_index.list_files(data_dir)
        y_files = self.y_data_index.list_files(data_dir)

        # zip() stops at the shorter list, which would silently drop unmatched
        # files; reject a count mismatch up front instead.
        # AssertionError is raised explicitly (rather than via `assert`) so the
        # checks keep their original exception type but still run under `-O`.
        if len(X_files) != len(y_files):
            raise AssertionError(
                f'X has {len(X_files)} files but y has {len(y_files)} files in {data_dir}')

        file_pair_lst = list(zip(X_files, y_files))
        for X_file, y_file in file_pair_lst:
            if os.path.basename(X_file) != os.path.basename(y_file):
                raise AssertionError(f'X: {X_file} != y: {y_file}')
        return file_pair_lst

## Dataset

In [5]:
class AudioDataset(IterableDataset):
    """Streams aligned, fixed-size windows from paired X/y audio files.

    Files are read one pair at a time and cut into windows along the last
    (sample) axis; each iteration step yields `(X_window, y_window)`.

    NOTE: `__iter__` does not shard by DataLoader worker, so with
    `num_workers > 0` every worker would walk the full file list.
    """

    def __init__(self, Xy_data_index_pair: XyDataIndexPair, data_dir=data_dir, window_size=16384,
                 window_size_overlap_percentage=0.5, target_sample_rate=16000,
                 pad_incomplete_windows=True):
        """
        Args:
            Xy_data_index_pair: Paired X/y indices whose files are iterated.
            data_dir: Root directory containing the extracted archive folders.
            window_size: Number of samples in each yielded window.
            window_size_overlap_percentage: Fraction of `window_size` shared by
                consecutive windows (0.5 means half a window; despite the name
                this is a fraction, not a percent value).
            target_sample_rate: Sample rate every file must already have.
            pad_incomplete_windows: If True (default), windows at the end of a
                file that are shorter than `window_size` are zero-padded on the
                right so all windows have the same length and can be batched.
                Set to False to get the raw, shorter tail windows.

        Raises:
            ValueError: If the overlap leaves no positive hop between windows.
        """
        super().__init__()
        window_overlap = int(window_size_overlap_percentage * window_size)
        # A hop of zero makes range() fail and a negative hop yields nothing,
        # so reject both here rather than at iteration time.
        if window_size - window_overlap <= 0:
            raise ValueError(
                f'window overlap ({window_overlap}) must be smaller than window_size ({window_size})')

        self.data_dir = data_dir
        self.Xy_data_index_pair = Xy_data_index_pair
        self.Xy_file_pairs = Xy_data_index_pair.list_file_pairs(data_dir)
        self.target_sample_rate = target_sample_rate
        self.window_size = window_size
        self.window_overlap = window_overlap
        self.pad_incomplete_windows = pad_incomplete_windows

    def _read_audio(self, file):
        """Load `file` and return `(waveform, sample_rate)`.

        A leading dimension of size 1 is squeezed away from the loaded tensor.

        Raises:
            AssertionError: If the file's sample rate differs from
                `target_sample_rate`.
        """
        waveform, sample_rate = torchaudio.load(file)
        # Raised explicitly so the check is not stripped under `python -O`.
        if sample_rate != self.target_sample_rate:
            raise AssertionError(
                f'Expected sample rate of {self.target_sample_rate} but got {sample_rate} for file {file}')
        return waveform.squeeze(0), sample_rate

    def _pad_to_window(self, window):
        """Zero-pad `window` on the right of its last axis up to `window_size`."""
        missing = self.window_size - window.shape[-1]
        if missing <= 0:
            return window
        return torch.nn.functional.pad(window, (0, missing))

    def __iter__(self):
        """Yield `(X_window, y_window)` tensors for every file pair, in order.

        Raises:
            AssertionError: If the two files of a pair differ in sample count.
        """
        hop = self.window_size - self.window_overlap
        for X_file, y_file in self.Xy_file_pairs:
            X_waveform, _ = self._read_audio(X_file)
            y_waveform, _ = self._read_audio(y_file)
            num_samples = X_waveform.shape[-1]
            if num_samples != y_waveform.shape[-1]:
                raise AssertionError(
                    f'Expected same number of samples but got {X_waveform.shape[-1]} and {y_waveform.shape[-1]} for files {X_file} and {y_file}')
            for start in range(0, num_samples, hop):
                X_waveform_window = X_waveform[..., start:start + self.window_size]
                y_waveform_window = y_waveform[..., start:start + self.window_size]
                if self.pad_incomplete_windows:
                    # Slices near the end of a file come out short; without
                    # padding, batching windows of unequal length fails.
                    X_waveform_window = self._pad_to_window(X_waveform_window)
                    y_waveform_window = self._pad_to_window(y_waveform_window)
                yield X_waveform_window, y_waveform_window

## DataLoader

In [6]:
class AudioDataModule(pl.LightningDataModule):
    """Lightning data module for the paired noisy/clean wav archives.

    Downloads and extracts the archives listed in `Xy_data_index_lst`,
    resamples the wav files to `sample_rate`, and exposes windowed
    `AudioDataset`s through the standard dataloader hooks.
    """

    def __init__(self, data_dir=data_dir, sample_rate=16000, batch_size=400):
        """
        Args:
            data_dir: Root directory for downloaded zips and extracted audio.
            sample_rate: Sample rate (Hz) the audio is converted to and that
                the datasets require.
            batch_size: Number of windows per batch in every dataloader.
        """
        super().__init__()
        self.data_dir = data_dir
        self.sample_rate = sample_rate
        self.batch_size = batch_size
        self.bitstream_prefix = 'http://datashare.is.ed.ac.uk/bitstream/handle/10283/1942'

        self.Xy_data_index_lst = [
            XyDataIndexPair(
                X_data_index=DataIndex(url=f'{self.bitstream_prefix}/noisy_trainset_wav.zip', id='noisy'),
                y_data_index=DataIndex(url=f'{self.bitstream_prefix}/clean_trainset_wav.zip', id='clean'),
                stage='train',
            ),
            XyDataIndexPair(
                X_data_index=DataIndex(url=f'{self.bitstream_prefix}/noisy_testset_wav.zip', id='noisy'),
                y_data_index=DataIndex(url=f'{self.bitstream_prefix}/clean_testset_wav.zip', id='clean'),
                stage='val',
            ),
        ]

    def prepare_data(self):
        '''
        Download, extract and resample every archive in `Xy_data_index_lst`.
        '''
        for Xy_data_index_pair in tqdm(self.Xy_data_index_lst):
            for data_index in (Xy_data_index_pair.X_data_index, Xy_data_index_pair.y_data_index):
                wav_dir = os.path.join(self.data_dir, data_index.dir)
                self.download_and_extract(data_index.url, wav_dir)
                # Conversion happens in place: source and target folder are the same.
                self.convert_wavs(wav_dir, wav_dir)

    def _find_index_pair_by_stage(self, stage):
        """Return the first index pair registered for `stage` (IndexError if none)."""
        return [index_pair for index_pair in self.Xy_data_index_lst if index_pair.stage == stage][0]

    def setup(self, stage=None):
        '''
        Create the datasets needed for the requested stage.

        The 'val' archive pair backs the validation, test and predict datasets.

        :param stage: one of 'fit', 'validate', 'test', 'predict', or None to
            set up every dataset
        :return: None
        '''
        valid_index_pair = self._find_index_pair_by_stage('val')
        valid_dataset = AudioDataset(valid_index_pair, data_dir=self.data_dir,
                                     target_sample_rate=self.sample_rate, window_size_overlap_percentage=0.0)

        if stage in ('fit', None):
            train_index_pair = self._find_index_pair_by_stage('train')
            self.train_dataset = AudioDataset(train_index_pair, data_dir=self.data_dir,
                                              target_sample_rate=self.sample_rate, window_size_overlap_percentage=0.5)
        # 'validate' needs the validation set too, otherwise val_dataloader()
        # would hit a missing attribute.
        if stage in ('fit', 'validate', None):
            self.valid_dataset = valid_dataset
        if stage in ('test', 'predict', None):
            self.predict_dataset = valid_dataset
            self.test_dataset = valid_dataset

    def download_and_extract(self, url, extract_to):
        """Download `url` to `<extract_to>.zip` and unzip it into `extract_to`.

        Each step is skipped when its result already exists on disk.
        """
        zip_path = f'{extract_to}.zip'
        if not os.path.exists(zip_path):
            print(f'DOWNLOADING DATASET FROM {url}...')
            zip_dir = os.path.dirname(zip_path)
            if zip_dir:  # empty when the zip lives in the current directory
                os.makedirs(zip_dir, exist_ok=True)
            # Download under a temporary name so an interrupted transfer is
            # never mistaken for a complete archive on the next run.
            partial_path = f'{zip_path}.part'
            request.urlretrieve(url, partial_path, ProgressBar())
            os.replace(partial_path, zip_path)
        if not os.path.exists(extract_to):
            print(f'INFLATING ZIP FROM {zip_path} TO {extract_to} ...')
            os.makedirs(extract_to)
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(extract_to)

    @staticmethod
    def _wav_sample_rate(path):
        """Return the frame rate from a wav header, or None if it cannot be read."""
        import wave  # local import: only needed by this helper

        try:
            with wave.open(path, 'rb') as wav_reader:
                return wav_reader.getframerate()
        except (wave.Error, EOFError):
            # NOTE(review): files the stdlib reader cannot parse are treated as
            # "needs conversion" — confirm this suits the actual data.
            return None

    def convert_wavs(self, source_dir, target_dir):
        """Resample every `.wav` in `source_dir` to `self.sample_rate` using sox.

        Results are written to `target_dir`, which may be the same folder as
        `source_dir`. Files already present in `target_dir` at the requested
        rate are skipped, so repeated calls are cheap.

        Raises:
            subprocess.CalledProcessError: If sox exits with a non-zero status.
        """
        print(f'CONVERTING WAVS FROM {source_dir} TO {target_dir}')
        os.makedirs(target_dir, exist_ok=True)
        for wav_file in sorted(os.listdir(source_dir)):
            if not wav_file.endswith('.wav'):
                continue
            source_path = os.path.join(source_dir, wav_file)
            target_path = os.path.join(target_dir, wav_file)
            if os.path.exists(target_path) and self._wav_sample_rate(target_path) == self.sample_rate:
                continue
            # sox must not write onto its own input, so write to a sibling
            # file first. The '.part' suffix keeps a leftover file from being
            # picked up as audio; '-t wav' sets the output type explicitly.
            partial_path = f'{target_path}.part'
            try:
                subprocess.run(
                    ['sox', source_path, '-r', str(self.sample_rate), '-t', 'wav', partial_path],
                    check=True)
                os.replace(partial_path, target_path)
            finally:
                if os.path.exists(partial_path):
                    os.remove(partial_path)

    def train_dataloader(self):
        """Return the training loader.

        `shuffle` is not passed: DataLoader rejects `shuffle=True` for an
        IterableDataset, so windows arrive in file order.
        """
        return DataLoader(self.train_dataset, batch_size=self.batch_size)

    def val_dataloader(self):
        """Return the validation loader."""
        return DataLoader(self.valid_dataset, batch_size=self.batch_size, shuffle=False)

    def test_dataloader(self):
        """Return the test loader."""
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)

## Create datamodule instance

In [7]:
# Build the data module rooted at the shared data_dir.
datamodule = AudioDataModule(data_dir=data_dir)

In [8]:
# Download and extract the archives (skipped when already on disk), then run the wav conversion step.
datamodule.prepare_data()

100%|██████████| 2/2 [00:00<00:00, 2003.49it/s]

CONVERTING WAVS FROM data\noisy_trainset_wav TO data\noisy_trainset_wav
CONVERTING WAVS FROM data\clean_trainset_wav TO data\clean_trainset_wav
CONVERTING WAVS FROM data\noisy_testset_wav TO data\noisy_testset_wav
CONVERTING WAVS FROM data\clean_testset_wav TO data\clean_testset_wav





In [9]:
# Create the train and validation datasets.
datamodule.setup(stage='fit')

In [10]:
# Take the first (X, y) window pair straight from the dataset, bypassing the DataLoader.
X,y = next(iter(datamodule.train_dataset))

In [12]:
# Show the shapes of the sampled X and y windows.
X.shape,y.shape

(torch.Size([16384]), torch.Size([16384]))