In [None]:
from sound_field import SoundField
from signal_info import signal_info
from optimizer import optimizer
from optimizer_v2 import optimizer_v2
from DoA_est import DoA_via_bands
import utils
import numpy as np
import math
import time
import torch
from typing import List,Union
import os
from collections import defaultdict,Counter
import matplotlib.pyplot as plt
import random

LEBEDEV = 'lebedev'
POINTS_162 = '162_points'

In [9]:
dataset_path = r"data\WSJ0"

In [10]:
def generate_random_doas(num_speakers : int ,min_theta : float = 0,max_theta : float = 90,min_phi : float = -180,max_phi : float= 180,min_dist : float = 15):
    doas = list()
    while len(doas) < num_speakers:
        random_doa = np.array([random.uniform(min_theta,max_theta),random.uniform(min_phi,max_phi)])
        if any([np.sqrt(np.sum((random_doa - doa)**2)) < min_dist for doa in doas]):
            continue
        doas.append(random_doa)
    return doas

In [None]:
def create_dataset(
    wave_files_path: List[str],
    input_order: int = 1,
    upscaled_order: int = 3,
    min_num_speakers: int = 1,
    max_num_speakers: int = 4,
    min_theta: float = 0,
    max_theta: float = 90,
    min_phi: float = -180,
    max_phi: float = 180,
    min_dist: float = 15,
    sr : int = 16000,
    grid_type : str = POINTS_162
):
    config_dict = locals()
    total_num_wav_files = len(wave_files_path)
    num_wave_files_used = 0

    data_set = list()
    while num_wave_files_used < total_num_wav_files:
        num_speakers = min(
            random.randint(min_num_speakers, max_num_speakers), total_num_wav_files - num_wave_files_used
        )
        doas = generate_random_doas(
            num_speakers,
            min_theta=min_theta,
            max_theta=max_theta,
            min_phi=min_phi,
            max_phi=max_phi,
            min_dist=min_dist,
        )
        signals = [
            signal_info(
                signal_path=wave_files_path[num_wave_files_used + i],
                th=doa[0],
                ph=doa[1],
            )
            for i, doa in enumerate(doas)
        ]
        sound_field = SoundField()
        sound_field.anm_t = sound_field.create(
            signals=signals,
            order=input_order,
            debug=False,
            grid_type=grid_type,
            sr=sr
        )
        sound_field.gt_anm_t_upscaled = sound_field.create(
            signals=signals,
            order=upscaled_order,
            debug=False,
            grid_type=grid_type,
            sr=sr
        )
        data_set.append(sound_field)
        num_wave_files_used += num_speakers
    return data_set,config_dict

In [16]:
def get_wav_files_in_folder(path,shuffle=True):
    wav_files = list()
    for root, dirs, files in os.walk(path):
        for file in files:
            if file.endswith(".wav"):
                wav_files.append(os.path.join(root, file))
    if shuffle:
        random.shuffle(wav_files)
    return wav_files


In [17]:
dataset_name = os.path.basename(dataset_path)
dataset = dict()
for folder in os.listdir(dataset_path):
    if not(os.path.isdir(os.path.join(dataset_path,folder))):
        continue
    dataset[folder] = dict()
    wav_files = get_wav_files_in_folder(os.path.join(dataset_path,folder))
    print(f"Creating {folder} dataset for {dataset_name}")
    dataset[folder]['data'],dataset[folder]['config'] = create_dataset(wav_files)
    print(f"Complete - Size : {len(dataset[folder]['data'])}")


Creating test dataset for WSJ0
Complete - Size : 24
Creating train dataset for WSJ0
Complete - Size : 2284
Creating validation dataset for WSJ0
Complete - Size : 139


In [19]:
torch.save(dataset,r'data\WSJ0\Dataset_03_01_25.pt')

In [2]:
torch.load(r'data\WSJ0\train_data_set.pt')

  torch.load(r'data\WSJ0\train_data_set.pt')


[<sound_field.SoundField at 0x1d183a4ef90>,
 <sound_field.SoundField at 0x1d183a4f9e0>,
 <sound_field.SoundField at 0x1d183a84110>,
 <sound_field.SoundField at 0x1d183a84680>,
 <sound_field.SoundField at 0x1d183a84d70>,
 <sound_field.SoundField at 0x1d183a85190>,
 <sound_field.SoundField at 0x1d182c3d760>,
 <sound_field.SoundField at 0x1d183a4d010>,
 <sound_field.SoundField at 0x1d183a86060>,
 <sound_field.SoundField at 0x1d183a86570>,
 <sound_field.SoundField at 0x1d1829a9790>,
 <sound_field.SoundField at 0x1d183a86e10>,
 <sound_field.SoundField at 0x1d183a87290>,
 <sound_field.SoundField at 0x1d183a878c0>,
 <sound_field.SoundField at 0x1d183a87e60>,
 <sound_field.SoundField at 0x1d183a98350>,
 <sound_field.SoundField at 0x1d183a98740>,
 <sound_field.SoundField at 0x1d183a98e90>,
 <sound_field.SoundField at 0x1d183a996d0>,
 <sound_field.SoundField at 0x1d183a99e50>,
 <sound_field.SoundField at 0x1d183a9a420>,
 <sound_field.SoundField at 0x1d181b7aa20>,
 <sound_field.SoundField at 0x1d