In [15]:
import torchaudio
import os
from multiprocessing import Pool, cpu_count
import math


from paths import *

In [12]:
# Inspect the bit depth of one reference recording; the loop in the next cell
# checks that every file in wav_path matches this value (16).
torchaudio.info(os.path.join(wav_path, "s0101a.wav")).bits_per_sample

16

In [13]:
# Print the name of every recording in wav_path whose bit depth is not 16.
# Prints nothing when all files match.
for wav in os.listdir(wav_path):
    if torchaudio.info(os.path.join(wav_path, wav)).bits_per_sample == 16:
        continue
    print(wav)

In [16]:
def collaboration_single_work(my_work_pool, my_det_dir):
    """Report every audio file in ``my_work_pool`` that is not single-channel.

    Intended as the per-worker function for ``MultiprocessManager``: each call
    receives one chunk of file names that live in ``my_det_dir``.

    Args:
        my_work_pool: list of file names (relative to ``my_det_dir``) to check.
        my_det_dir: directory containing those files.

    Returns:
        List of the file names whose channel count is not 1, in input order.
        Each such name is also printed as soon as it is found.
    """
    if not my_work_pool:
        # Empty chunk: nothing to check, and indexing [0] / [-1] below would fail.
        return []
    print("Working from {} to {}".format(my_work_pool[0], my_work_pool[-1]))
    non_mono = []
    for rec_name in my_work_pool:
        if torchaudio.info(os.path.join(my_det_dir, rec_name)).num_channels != 1:
            # Report the file being checked here, not a stale notebook global.
            print(rec_name)
            non_mono.append(rec_name)
    print("Work from {} to {} ends".format(my_work_pool[0], my_work_pool[-1]))
    return non_mono

In [17]:
class MultiprocessManager:
    """Process the files of one directory in parallel with a process pool.

    The directory listing is split into contiguous chunks and each chunk is
    handed to ``fun(chunk, my_det_dir)`` in a separate worker process.
    """

    def __init__(self, fun, my_det_dir, num_workers=4):
        """
        Args:
            fun: picklable callable taking ``(list_of_file_names, my_det_dir)``.
            my_det_dir: directory whose entries (``os.listdir``) are processed.
            num_workers: size of the process pool and upper bound on chunks.
        """
        self.fun = fun
        self.my_det_dir = my_det_dir
        self.num_workers = num_workers

    def divide_work(self, work):
        """Split ``work`` into at most ``num_workers`` contiguous chunks.

        The chunk size is rounded up, so fewer than ``num_workers`` chunks may
        be produced (e.g. 3 items with 4 workers gives 3 chunks). Returns an
        empty list for empty ``work``.
        """
        if not work:
            # range(..., step=0) would raise ValueError below.
            return []
        # determine the number of items per worker
        items_per_worker = math.ceil(len(work) / self.num_workers)

        # divide the work into chunks
        return [work[i:i + items_per_worker] for i in range(0, len(work), items_per_worker)]

    def collaboration_work(self):
        """Run ``fun`` on every chunk of ``my_det_dir``'s listing in a process pool.

        Returns:
            List with one return value of ``fun`` per chunk, in chunk order
            (empty when the directory is empty).

        Raises:
            Any exception raised by ``fun`` inside a worker is re-raised here.
        """
        flat_tasks = os.listdir(self.my_det_dir)
        task_pools = self.divide_work(flat_tasks)
        print(self.num_workers)
        p = Pool(self.num_workers)
        try:
            # Submit one job per chunk actually produced. divide_work can return
            # fewer chunks than num_workers, so indexing by worker id overruns.
            async_results = [
                p.apply_async(self.fun, args=(chunk, self.my_det_dir))
                for chunk in task_pools
            ]
            print('Waiting for all subprocesses done...')
            p.close()
            p.join()
        except BaseException:
            # e.g. KeyboardInterrupt: stop the workers instead of orphaning them.
            p.terminate()
            p.join()
            raise
        # apply_async discards worker exceptions unless .get() is called.
        results = [res.get() for res in async_results]
        print('All subprocesses done.')
        return results

In [19]:
# random sampling: check the channel count of every file under
# phone_seg_random_path (defined in `paths`), one worker per CPU core.
n_worker = cpu_count()
mpm = MultiprocessManager(
    collaboration_single_work,
    phone_seg_random_path,
    num_workers=n_worker,
)
mpm.collaboration_work();

32
Working from s2201a_00005265.wav to s3504a_00002769.wav
Working from s1804a_00000349.wav to s0802b_00004583.wav
Working from s0902a_00000601.wav to s1803a_00003720.wav
Working from s1903a_00003691.wav to s0302b_00003992.wav
Working from s2903a_00005266.wav to s2901a_00003195.wav
Working from s0902a_00005682.wav to s2401b_00005647.wav
Working from s1003a_00000756.wav to s0902b_00000782.wav
Working from s1302b_00000307.wav to s2802a_00003113.wav
Waiting for all subprocesses done...
Working from s3202b_00006973.wav to s1003a_00003467.wav
Working from s3803a_00003896.wav to s1101a_00004425.wav
Working from s0501a_00003348.wav to s0101a_00002667.wav
Working from s2001a_00001772.wav to s1602a_00002705.wav
Working from s2301b_00005249.wav to s2603a_00003034.wav
Working from s0302b_00003562.wav to s1202a_00003046.wav
Working from s0102b_00004778.wav to s0102b_00003632.wav
Working from s4003b_00000424.wav to s2702b_00005323.wav
Working from s1901a_00005127.wav to s1702a_00004682.wav
Working 

KeyboardInterrupt: 

In [25]:
import torch.nn as nn
import torch
import torchaudio
import os

from paths import *

In [26]:
# Sample rate (Hz) passed to the mel transform, and its FFT window size.
REC_SAMPLE_RATE, N_FFT = 16000, 400

In [27]:
class MyTransform(nn.Module):
    """Waveform -> time-major mel spectrogram.

    Wraps ``torchaudio.transforms.MelSpectrogram`` and moves the frame axis in
    front of the mel-bin axis, giving ``(L, F)`` for single-channel input
    (``L`` frames, ``F`` mel bins).
    """

    def __init__(self, sample_rate, n_fft, **mel_kwargs):
        """
        Args:
            sample_rate: sample rate of the input waveform, in Hz.
            n_fft: FFT size used by the mel spectrogram.
            **mel_kwargs: extra keyword arguments forwarded unchanged to
                ``torchaudio.transforms.MelSpectrogram`` (e.g. ``n_mels``).
        """
        super().__init__()
        self.transform = torchaudio.transforms.MelSpectrogram(sample_rate, n_fft=n_fft, **mel_kwargs)

    def forward(self, waveform):
        """Compute the mel spectrogram of ``waveform`` with frames first.

        Returns:
            ``(L, F)`` when every leading axis of the spectrogram is singleton
            (e.g. a ``(1, T)`` mono waveform); otherwise the leading axes are
            kept and the result is ``(..., L, F)``.
        """
        mel_spec = self.transform(waveform)  # (..., F, L)
        # Drop only leading singleton axes (e.g. the mono channel axis). A bare
        # squeeze() would also drop L for a one-frame clip, breaking the transpose.
        while mel_spec.dim() > 2 and mel_spec.size(0) == 1:
            mel_spec = mel_spec.squeeze(0)
        return mel_spec.transpose(-2, -1)  # (..., F, L) -> (..., L, F)

In [28]:
waveform, sample_rate = torchaudio.load(os.path.join(wav_path, "s0101a.wav"))
# The mel transform is configured with REC_SAMPLE_RATE rather than the file's
# own rate, so fail loudly if they disagree instead of silently mis-scaling
# the mel filterbank.
assert sample_rate == REC_SAMPLE_RATE, "expected {} Hz, got {} Hz".format(REC_SAMPLE_RATE, sample_rate)

In [29]:
# Mel-spectrogram front end using the notebook's configured rate and FFT size.
t = MyTransform(n_fft=N_FFT, sample_rate=REC_SAMPLE_RATE)



In [30]:
# Mel spectrogram of the loaded recording, frames first: (L, F).
mel = t(waveform)

In [31]:
# (channels, samples) of the loaded recording; a single channel here.
waveform.shape

torch.Size([1, 9969854])

In [32]:
# Linear-power mel values (no log applied); magnitudes span many orders.
mel

tensor([[0.0000e+00, 3.9037e-09, 2.1019e-08,  ..., 1.6256e-06, 7.4967e-06,
         9.0985e-06],
        [0.0000e+00, 2.2318e-08, 1.2016e-07,  ..., 2.3912e-06, 2.7047e-06,
         3.8634e-06],
        [0.0000e+00, 4.6319e-08, 2.4939e-07,  ..., 1.4409e-06, 2.3818e-06,
         2.8968e-06],
        ...,
        [0.0000e+00, 8.7722e-05, 4.7232e-04,  ..., 5.4108e-06, 2.6126e-06,
         1.5923e-06],
        [0.0000e+00, 1.5115e-04, 8.1386e-04,  ..., 2.8069e-06, 5.8884e-06,
         2.5544e-06],
        [0.0000e+00, 1.1784e-03, 6.3448e-03,  ..., 1.2396e-06, 2.0887e-06,
         2.2797e-06]])

In [33]:
# Persist the computed mel spectrogram to the working directory.
torch.save(obj=mel, f="save.pt")