#### Testing dataset

In [1]:
import random
from glob import glob
import numpy as np
import torch as th
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from typing import Optional
import argparse

from utils.funtctional import handle_scp
from utils.data_processing import read_wav
from utils.load_config import load_config 

# Small epsilon constant. Not referenced in the cells shown here; presumably
# meant for numerical stability elsewhere (TODO confirm, or remove if unused).
EPS = 1e-8

In [6]:
# Load the training config; load_config's result is unpacked into the config
# mapping and a checkpoint-folder value.
cfg, ckpt_folder = load_config('./config/train_rnn.yml')
# Last expression of the cell: the notebook displays the data section (output below).
cfg['data']

{'data_mixed': './mixed_data/',
 'data_targets': './targets/',
 'train_percent': 0.8,
 'valid_percent': 0.1,
 'test_percent': 0.1,
 'num_workers': 1,
 'batch_size': 1,
 'sample_rate': 16000,
 'chunk_size': 32000,
 'least_size': 16000}

In [7]:
class AudioDataset(Dataset):
    """Map-style dataset pairing mixture waveforms with their reference sources.

    Every file listed in each scp file (parsed by ``handle_scp``) is read with
    ``read_wav`` eagerly, when the dataset is constructed, and kept in memory.
    Item ``idx`` of the mixture list is paired with item ``idx`` of every
    reference list, so all scp files must enumerate utterances in the same
    order and with the same count.

    Args:
        mix_scp: Path to the scp file listing the mixture recordings.
        ref_scp: Paths to the scp files listing the reference sources (one scp
            per source). A single path string is also accepted.
        sample_rate: Expected sample rate of every file.
        chunk_size: Stored on the instance; not used by this class itself.
        least_size: Stored on the instance; not used by this class itself.

    Raises:
        RuntimeError: If a file's sample rate differs from ``sample_rate``.
        ValueError: If a reference scp yields a different number of utterances
            than the mixture scp.
    """

    def __init__(self, mix_scp: str, ref_scp: list, sample_rate: int = 8000, chunk_size: int = 32000, least_size: int = 16000):
        self.sample_rate = sample_rate
        self.chunk_size = chunk_size
        self.least_size = least_size

        # A bare string would otherwise be iterated character by character,
        # treating each character as an scp path.
        if isinstance(ref_scp, str):
            ref_scp = [ref_scp]

        self.mix_audio = self._load_audio(mix_scp)
        self.ref_audio = [self._load_audio(ref) for ref in ref_scp]
        # Pairing is purely positional, so catch count mismatches up front
        # instead of silently misaligning or failing later in __getitem__.
        self._validate_lengths()

    def _load_audio(self, scp_path):
        """Read every file listed in ``scp_path`` (in scp order) and return them as a list."""
        index_dict = handle_scp(scp_path)
        audio_data = []
        for key in index_dict.keys():
            src, sr = read_wav(index_dict[key], return_rate=True)
            if sr != self.sample_rate:
                raise RuntimeError(f"Sample rate mismatch: {sr} vs {self.sample_rate}")
            audio_data.append(src)
        return audio_data

    def _validate_lengths(self):
        """Raise ValueError unless every reference list matches the mixture list in length."""
        n_mix = len(self.mix_audio)
        for i, ref in enumerate(self.ref_audio):
            if len(ref) != n_mix:
                raise ValueError(
                    f"Reference scp #{i} has {len(ref)} utterances, "
                    f"but the mixture scp has {n_mix}"
                )

    def __len__(self):
        """Number of mixture utterances."""
        return len(self.mix_audio)

    def __getitem__(self, idx):
        """Return ``(mixture, [reference_0, reference_1, ...])`` for utterance ``idx``."""
        mix_sample = self.mix_audio[idx]
        ref_samples = [ref[idx] for ref in self.ref_audio]
        return mix_sample, ref_samples


class AudioDataModule(pl.LightningDataModule):
    """Lightning data module serving ``AudioDataset`` train/validation splits.

    Args:
        train_scp: Mixture scp file for the training split.
        val_scp: Mixture scp file for the validation split.
        ref_scp_train: Reference scp files for the training split.
        ref_scp_val: Reference scp files for the validation split.
        batch_size: Batch size used by both dataloaders.
        sample_rate, chunk_size, least_size: Forwarded to ``AudioDataset``.
        num_workers: Number of DataLoader worker processes.
        pin_memory: Forwarded to ``DataLoader``.
        seed: Seeds Python, NumPy and torch (CPU and all CUDA devices) at
            construction time, and seeds the training shuffle order.
    """

    def __init__(self, train_scp, val_scp, ref_scp_train, ref_scp_val, batch_size=128, sample_rate=8000, chunk_size=32000, 
                 least_size=16000, num_workers=4, pin_memory=False, seed=42):
        super().__init__()
        self.train_scp = train_scp
        self.val_scp = val_scp
        self.ref_scp_train = ref_scp_train
        self.ref_scp_val = ref_scp_val
        self.batch_size = batch_size
        self.sample_rate = sample_rate
        self.chunk_size = chunk_size
        self.least_size = least_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.seed = seed

        # Built lazily by setup().
        self.train_dataset = None
        self.val_dataset = None

        self._set_seed(seed)

    def setup(self, stage: Optional[str] = None):
        """Build the train and validation datasets once; later calls reuse them.

        Lightning may call ``setup`` once per stage (e.g. fit, then validate).
        ``AudioDataset`` reads every audio file at construction, so rebuilding
        on each call would re-read all files.
        """
        if self.train_dataset is None:
            self.train_dataset = AudioDataset(self.train_scp, self.ref_scp_train, sample_rate=self.sample_rate, 
                                              chunk_size=self.chunk_size, least_size=self.least_size)
        if self.val_dataset is None:
            self.val_dataset = AudioDataset(self.val_scp, self.ref_scp_val, sample_rate=self.sample_rate, 
                                            chunk_size=self.chunk_size, least_size=self.least_size)

    def train_dataloader(self):
        """Return the shuffled training loader."""
        # A dedicated, seeded generator makes the shuffle order independent of
        # how much global torch RNG was consumed after __init__ (e.g. model init).
        generator = th.Generator()
        generator.manual_seed(self.seed)
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, 
                          pin_memory=self.pin_memory, worker_init_fn=self._init_worker_seed, generator=generator)

    def val_dataloader(self):
        """Return the (unshuffled) validation loader."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, 
                          pin_memory=self.pin_memory, worker_init_fn=self._init_worker_seed)

    def _set_seed(self, seed: int):
        """Seed Python's ``random``, NumPy, and torch (CPU and all CUDA devices)."""
        random.seed(seed)
        np.random.seed(seed)
        th.manual_seed(seed)
        th.cuda.manual_seed_all(seed)

    def seed_worker(self, worker_id):
        """Seed NumPy and ``random`` inside a DataLoader worker (kept for callers)."""
        self._init_worker_seed(worker_id)

    @staticmethod
    def _init_worker_seed(worker_id):
        """Derive NumPy/``random`` seeds from the worker's torch seed.

        Used as ``worker_init_fn`` instead of the bound ``seed_worker``, so that
        spawn-started workers do not have to pickle the whole data module
        (including both in-memory datasets) along with the callback.
        """
        worker_seed = th.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

In [None]:
# NOTE(review): the keys of cfg['data'] displayed above (data_mixed, data_targets,
# train_percent, valid_percent, test_percent, ...) do not match AudioDataModule's
# parameters (train_scp, val_scp, ref_scp_train, ref_scp_val are required), so
# this call raises TypeError as written. Map the config to scp paths explicitly
# and pass only the matching keys (batch_size, sample_rate, chunk_size,
# least_size, num_workers).
datamodule = AudioDataModule(**cfg['data'])

In [None]:
# NOTE(review): make_dataloader is neither defined nor imported in the cells
# shown here, so this raises NameError unless it was defined in a cell that is
# not shown. Possibly superseded by datamodule.train_dataloader() /
# datamodule.val_dataloader() after datamodule.setup() — confirm.
train_dataloader, val_dataloader = make_dataloader(**cfg['data'])
# Loaders keyed by split name.
dataloaders = {'train': train_dataloader, 'valid': val_dataloader}