In [3]:
from sound_field import SoundField,divide_to_subbands,divide_to_time_windows,create_grid
from signal_info import signal_info
from optimizer import optimizer
from optimizer_v2 import optimizer_v2
from DoA_est import DoA_via_bands
import utils
import numpy as np
import math
import time
import torch
from typing import List,Union
import os
from collections import defaultdict,Counter
import matplotlib.pyplot as plt
import random
from tqdm import tqdm
import multiprocessing as mp
import inspect



# Grid-type identifier strings; POINTS_162 is the default `grid_type` of create_dataset.
LEBEDEV, POINTS_162 = 'lebedev', '162_points'

In [4]:
# Root folder of the WSJ0 data; each sub-folder (e.g. train/test/validation) is turned
# into one dataset split further down. Built with os.path.join so it resolves on any OS
# (a raw string with a backslash is only a directory separator on Windows).
dataset_path = os.path.join("data", "WSJ0")

In [5]:
def generate_random_doas(
    num_speakers: int,
    min_theta: float = 0,
    max_theta: float = 90,
    min_phi: float = -180,
    max_phi: float = 180,
    min_dist: float = 15,
    max_attempts: int = 100000,
):
    """Draw ``num_speakers`` random (theta, phi) directions with a minimum pairwise separation.

    Candidates are sampled uniformly per coordinate in
    [min_theta, max_theta] x [min_phi, max_phi] (theta first, then phi) and are
    rejected if they lie closer than ``min_dist`` to any already accepted direction.

    Separation is the Euclidean distance in (theta, phi) space. The phi difference
    is wrapped into [-180, 180), so azimuths on either side of the +/-180 seam
    (e.g. 179 and -179) are 2 apart, not 358. This assumes phi is in degrees, as
    the default (-180, 180) range suggests.
    NOTE(review): this is a planar approximation. If theta is a polar angle, points
    near the pole with a large phi gap are still physically close; a great-circle
    distance would be more accurate. TODO confirm the angle convention.

    Args:
        num_speakers: number of directions to return.
        min_theta, max_theta: sampling range for theta.
        min_phi, max_phi: sampling range for phi.
        min_dist: minimum allowed separation between any two returned directions.
        max_attempts: maximum number of candidates drawn. Guards against an endless
            loop when ``num_speakers`` directions cannot fit with ``min_dist``
            spacing inside the given ranges.

    Returns:
        list of ``num_speakers`` np.ndarray objects of shape (2,), each [theta, phi].

    Raises:
        RuntimeError: if ``max_attempts`` candidates are drawn before enough
            separated directions are found.
    """
    def separation(a, b):
        # Wrap the azimuth difference so the +/-180 seam is treated as adjacent.
        d_phi = (a[1] - b[1] + 180.0) % 360.0 - 180.0
        return np.hypot(a[0] - b[0], d_phi)

    doas = list()
    attempts = 0
    while len(doas) < num_speakers:
        if attempts >= max_attempts:
            raise RuntimeError(
                f"Could only place {len(doas)} of {num_speakers} DoAs at least "
                f"{min_dist} apart after {max_attempts} attempts; widen the angle "
                f"ranges or reduce min_dist."
            )
        attempts += 1
        candidate = np.array([random.uniform(min_theta, max_theta), random.uniform(min_phi, max_phi)])
        if all(separation(candidate, doa) >= min_dist for doa in doas):
            doas.append(candidate)
    return doas

In [6]:
def create_sound_field_wrapper(speaker_wavs,config):
    """Render one random multi-speaker scene at the input order and at the upscaled order.

    Each wav in ``speaker_wavs`` is given a random direction from
    ``generate_random_doas``, using the angle limits and spacing stored in ``config``.
    The resulting scene is rendered twice with the same ``SoundField`` instance:
    once at ``config['input_order']`` and once at ``config['upscaled_order']``.

    Args:
        speaker_wavs: wav file paths, one per speaker.
        config: dict that must provide 'min_theta', 'max_theta', 'min_phi',
            'max_phi', 'min_dist', 'input_order', 'upscaled_order', 'grid_type'
            and 'sr'.

    Returns:
        tuple (anm_t, gt_anm_t_upscaled): the ``SoundField.create`` results for the
        input order and the upscaled order, respectively.
    """
    placement_keys = ('min_theta', 'max_theta', 'min_phi', 'max_phi', 'min_dist')
    doas = generate_random_doas(len(speaker_wavs), **{key: config[key] for key in placement_keys})

    signals = [
        signal_info(signal_path=wav_path, th=direction[0], ph=direction[1])
        for wav_path, direction in zip(speaker_wavs, doas)
    ]

    sound_field = SoundField()

    def render(order):
        # Same signals, grid and sample rate; only the order differs between calls.
        return sound_field.create(
            signals=signals,
            order=order,
            debug=False,
            grid_type=config['grid_type'],
            sr=config['sr'],
        )

    # Sub-band / time-window splitting (divide_to_subbands, divide_to_time_windows)
    # is currently disabled, so both fields are returned unsplit.
    anm_t = render(config['input_order'])
    gt_anm_t_upscaled = render(config['upscaled_order'])
    return anm_t, gt_anm_t_upscaled

In [10]:
def create_dataset(
    wav_files_path: List[str],
    input_order: int = 1,
    upscaled_order: int = 3,
    num_bins : int = 45,
    window_length : int = 1024,
    max_num_windows : int = 10,
    min_num_speakers: int = 1,
    max_num_speakers: int = 4,
    min_theta: float = 0,
    max_theta: float = 90,
    min_phi: float = -180,
    max_phi: float = 180,
    min_dist: float = 15,
    sr : int = 16000,
    grid_type : str = POINTS_162,
):
    """Group wav files into random multi-speaker scenes and render each scene.

    ``wav_files_path`` is consumed in order and cut into consecutive groups. Each
    group's size is drawn with ``random.randint(min_num_speakers, max_num_speakers)``;
    the last group takes whatever files remain, so it may be smaller. Each group is
    rendered by ``create_sound_field_wrapper``. Groups that raise are reported and
    skipped (best effort), so the returned list can be shorter than the number of
    groups.

    Args:
        wav_files_path: wav file paths to distribute over scenes.
        input_order, upscaled_order: orders passed to the two renders of each scene.
        num_bins, window_length, max_num_windows: only recorded in the returned
            config; the sub-band / windowing step that would use them is not
            applied here.
        min_num_speakers, max_num_speakers: inclusive range for the group size.
        min_theta, max_theta, min_phi, max_phi, min_dist: DoA sampling limits and
            minimum spacing, forwarded to the per-scene DoA generator.
        sr: sample rate passed to the renderer.
        grid_type: grid identifier passed to the renderer and to ``create_grid``.

    Returns:
        tuple (data_set, config_dict):
            data_set is a list of (anm_t, gt_anm_t_upscaled) tuples.
            config_dict holds every argument except ``wav_files_path``, plus
            'P_th' and 'P_ph' from ``create_grid(grid_type)``.

    Raises:
        ValueError: if ``min_num_speakers < 1`` or ``min_num_speakers > max_num_speakers``.
    """
    # Snapshot the call arguments while the parameters are the only locals. This
    # replaces the inspect.currentframe() lookup and avoids holding a frame reference.
    call_args = dict(locals())
    config_dict = {name: value for name, value in call_args.items() if name != 'wav_files_path'}

    # A group size < 1 would produce empty (or negative-length) groups and break
    # the slicing loop below.
    if min_num_speakers < 1:
        raise ValueError(f"min_num_speakers must be >= 1, got {min_num_speakers}")
    if min_num_speakers > max_num_speakers:
        raise ValueError(
            f"min_num_speakers ({min_num_speakers}) must not exceed max_num_speakers ({max_num_speakers})"
        )

    total_num_wav_files = len(wav_files_path)
    num_wav_files_used = 0
    speaker_groups = list()
    while num_wav_files_used < total_num_wav_files:
        num_speakers = min(
            random.randint(min_num_speakers, max_num_speakers), total_num_wav_files - num_wav_files_used
        )
        speaker_groups.append(wav_files_path[num_wav_files_used:num_wav_files_used + num_speakers])
        num_wav_files_used += num_speakers

    data_set = list()
    for group in tqdm(speaker_groups):
        try:
            data_set.append(create_sound_field_wrapper(group, config_dict))
        except Exception as e:
            # Best effort: one unreadable file should not abort the whole split.
            print(f"Error occured while processing group {group} : {e}")

    config_dict['P_th'], config_dict['P_ph'], _ = create_grid(grid_type)
    return data_set, config_dict

In [8]:
def get_wav_files_in_folder(path,shuffle=True):
    """Recursively collect the ``.wav`` files under ``path``.

    The extension match is case-insensitive, so ``.wav`` and ``.WAV`` both count.
    Paths are sorted before the optional shuffle. ``os.walk`` yields entries in a
    filesystem-dependent order, so sorting makes the unshuffled result stable and
    makes the shuffled result reproducible for a given ``random`` seed.

    Args:
        path: root directory to search. A missing directory yields an empty list.
        shuffle: if True, shuffle the list in place with ``random.shuffle``.

    Returns:
        list of paths, each ``os.path.join(root, filename)`` as reported by ``os.walk``.
    """
    wav_files = sorted(
        os.path.join(root, filename)
        for root, _dirs, filenames in os.walk(path)
        for filename in filenames
        if filename.lower().endswith(".wav")
    )
    if shuffle:
        random.shuffle(wav_files)
    return wav_files


In [11]:
# Build one dataset split per sub-folder of dataset_path (e.g. train/test/validation).
# `random` drives file shuffling, speaker grouping and DoA sampling, so seed it, and
# walk the folders in sorted order, so the generated dataset is reproducible.
# NOTE(review): numpy/torch RNGs are not seeded here. TODO seed them too if
# SoundField.create draws random numbers.
SEED = 0
random.seed(SEED)

dataset_name = os.path.basename(dataset_path)
dataset = dict()
for folder in sorted(os.listdir(dataset_path)):
    folder_path = os.path.join(dataset_path, folder)
    if not os.path.isdir(folder_path):
        continue
    wav_files = get_wav_files_in_folder(folder_path)
    print(f"Creating {folder} dataset for {dataset_name}")

    split_data, split_config = create_dataset(wav_files)
    dataset[folder] = {'data': split_data, 'config': split_config}
    print(f"Complete - Size : {len(dataset[folder]['data'])}")


Creating test dataset for WSJ0


100%|██████████| 25/25 [00:00<00:00, 59.27it/s]


Complete - Size : 25
Creating train dataset for WSJ0


100%|██████████| 2312/2312 [00:43<00:00, 53.22it/s]


Complete - Size : 2312
Creating validation dataset for WSJ0


100%|██████████| 144/144 [00:02<00:00, 53.36it/s]

Complete - Size : 144





In [12]:
# Persist all splits (dict: folder name -> {'data', 'config'}) in one file inside the
# dataset folder. os.path.join keeps the path valid on non-Windows systems.
torch.save(dataset, os.path.join(dataset_path, 'Dataset_03_01_25.pt'))

In [22]:
# Reload a saved dataset file. NOTE: this is 'Dataset_testv2.pt', not the file written
# above, and (see the output of data.keys() below) it holds a single split with
# 'data' and 'config' keys rather than the per-folder dict.
# weights_only is passed explicitly. Newer torch defaults it to True, which only
# unpickles tensors and basic containers. This file is a self-generated pickle whose
# config may hold other objects (e.g. the grid returned by create_grid). Only load
# files you trust.
data = torch.load(os.path.join(dataset_path, 'Dataset_testv2.pt'), weights_only=False)

  data = torch.load(r'data\WSJ0\Dataset_testv2.pt')


In [14]:
# Top-level keys of the loaded file; the output below shows 'data' and 'config'.
data.keys()

dict_keys(['data', 'config'])

In [19]:
# First element of the first sample. Assuming the file was produced by create_dataset,
# each sample is the (anm_t, gt_anm_t_upscaled) tuple from create_sound_field_wrapper,
# so this is the input-order field anm_t. Prefer `.shape` for a compact view.
data['data'][0][0]

tensor([[[[-1.6131e-05,  1.0823e-05, -1.9621e-05,  1.3109e-06],
          [-1.6112e-05,  1.0818e-05, -1.9604e-05,  1.3114e-06],
          [-1.6093e-05,  1.0812e-05, -1.9587e-05,  1.3121e-06],
          ...,
          [ 3.2435e-07, -3.2230e-05, -9.2528e-06,  1.5634e-05],
          [ 2.4221e-07, -3.2170e-05, -9.2901e-06,  1.5600e-05],
          [ 1.6003e-07, -3.2109e-05, -9.3273e-06,  1.5565e-05]],

         [[-1.1732e-05,  6.9569e-06, -3.8003e-06, -5.3970e-06],
          [-1.1867e-05,  7.0557e-06, -3.8958e-06, -5.4317e-06],
          [-1.2001e-05,  7.1544e-06, -3.9910e-06, -5.4664e-06],
          ...,
          [-2.8753e-07, -2.1726e-05, -5.0344e-06,  9.5308e-06],
          [-4.4787e-07, -2.1673e-05, -5.1537e-06,  9.5122e-06],
          [-6.0784e-07, -2.1621e-05, -5.2725e-06,  9.4936e-06]],

         [[-4.4520e-06,  1.2950e-06, -8.2675e-06,  2.6715e-06],
          [-4.5399e-06,  1.3537e-06, -8.3241e-06,  2.6469e-06],
          [-4.6262e-06,  1.4118e-06, -8.3780e-06,  2.6215e-06],
      