In [2]:
import sys

sys.path.insert(0, '..')

In [3]:
from __future__ import annotations
import json

import torch
from torch.utils.data import DataLoader, TensorDataset
import torchaudio
from torchaudio.transforms import MelSpectrogram

from params import sample_rate, windowed_signal_length, num_mel_bands, overlap

In [4]:
class MelSpecPipeline(torch.nn.Module):
    def __init__(self, n_fft=windowed_signal_length, sample_rate=sample_rate, n_mel=num_mel_bands):
        super().__init__()
        self.mel_spec = MelSpectrogram(sample_rate=sample_rate, n_fft=n_fft, n_mels=n_mel, power=2)

    def forward(self, wave):
        assert wave.shape[0] == 1

        mel_spec = self.mel_spec(wave)
        return mel_spec
    
pipeline = MelSpecPipeline()

In [5]:
def check_audio_metadata(metadata):
    assert metadata.sample_rate == 16000
    assert metadata.num_channels == 1
    assert metadata.num_frames > 0

def createDataFromRecording(session_root, id):
    wav_path = session_root + "/session_" + str(id) + "_mixture.wav"
    json_path = session_root + "/session_" + str(id) + ".json"

    # check some metadata
    metadata = torchaudio.info(wav_path)
    check_audio_metadata(metadata)
    print(f'Metadata: {metadata}')

    # retrieve speech segments
    speech_segments = set()
    with open(json_path, 'r') as f:
        speech_info = json.load(f)
    for key in speech_info:
        if key.isdigit():
            for info in speech_info[key]:
                segment = (info["start"], info["stop"])
                assert segment[0] < segment[1]
                speech_segments.add(segment)
    print(speech_segments)

    # MFSC pipeline
    wave, _ = torchaudio.load(wav_path)
    mels = pipeline(wave)
    mels.squeeze_(0)
    # librosa.display.specshow(mels.numpy())
    print(f'Shape of mels: {mels.shape}')

    mel_length_time = (windowed_signal_length * num_mel_bands) / sample_rate
    print(f'mel_length_time: {mel_length_time}')

    num_data = mels.shape[1] // num_mel_bands * overlap
    print(f'num_data: {num_data}')

    X = [torch.ones(num_mel_bands, num_mel_bands) for _ in range(num_data)]
    y = [torch.zeros(1) for _  in range(num_data)]

    # we iterate over all complete mel bands possible. 
    # We drop the end bit of the recording if it isn't long enough for a whole MFSC spectrogram
    # We process with overlaps of 1/`overlap`
    for output_i, mel_i in enumerate(range(0, int(mels.shape[1] - num_mel_bands), num_mel_bands // overlap)):
        mel_start_time = mel_i * mel_length_time
        mel_end_time = (mel_i + 1) * mel_length_time

        for start, end in speech_segments:
            if mel_start_time < end and start < mel_end_time:
                y[output_i] = torch.ones(1)
        X[output_i] = mels[:, mel_i : mel_i + num_mel_bands].clone().detach()
        # print(f'X[{output_i}].shape: {X[output_i].shape}')
        # print(f'y shape: {y[output_i].shape}')
        assert X[output_i].shape == (num_mel_bands, num_mel_bands)
        assert y[output_i].shape == (1,)

    assert all(isinstance(x, torch.Tensor) for x in X), "all X entries must be tensors"
    assert all(x.shape == (num_mel_bands, num_mel_bands) for x in X), "inconsistent X shapes"
    assert all(isinstance(t, torch.Tensor) for t in y), "all y entries must be tensors"

    X = torch.stack(X, dim=0)
    y = torch.stack(y, dim=0)

    print(f'X.shape is {X.shape}')
    print(f'y.shape is {y.shape}')
    return X, y

X, y = createDataFromRecording(session_root="LibriParty/dataset/train/session_0", id=0)

Metadata: AudioMetaData(sample_rate=16000, num_frames=4783692, num_channels=1, bits_per_sample=32, encoding=PCM_F)
{(124.63, 136.975), (72.876, 82.036), (194.072, 205.542), (158.155, 162.49), (0.582, 16.477), (233.761, 246.866), (14.198, 25.438), (262.443, 270.543), (90.981, 106.051), (273.033, 288.833), (52.268, 68.272), (137.211, 151.031), (250.985, 265.71), (273.62, 287.805), (123.585, 134.32), (98.198, 112.692), (207.596, 223.196), (208.436, 220.976), (234.433, 249.023)}
Shape of mels: torch.Size([40, 18687])
mel_length_time: 1.28
num_data: 934
X.shape is torch.Size([934, 40, 40])
y.shape is torch.Size([934, 1])


In [8]:
def createDataset(root, num_sessions):
    X = []
    y = []
    for session in range(num_sessions):
        session_X, session_y = createDataFromRecording(session_root=root + "session_" + str(session), id=session)
        assert torch.is_tensor(session_X) and torch.is_tensor(session_y) 
        assert session_X.shape[1] == num_mel_bands and session_X.shape[1] == num_mel_bands and session_y.shape[1] == 1 and session_X.shape[0] == session_y.shape[0]

        X.append(session_X)
        y.append(session_y)
    X = torch.cat(X, dim=0)
    y = torch.cat(y, dim=0)

    print(f'X.shape is: {X.shape}')
    print(f'y.shape is: {y.shape}')

createDataset(root='LibriParty/dataset/dev/', num_sessions=50)

Metadata: AudioMetaData(sample_rate=16000, num_frames=4721080, num_channels=1, bits_per_sample=32, encoding=PCM_F)
{(80.508, 86.318), (292.097, 295.067), (107.296, 112.546), (6.534, 8.979), (144.639, 149.959), (22.107, 33.907), (204.186, 211.701), (11.362, 15.032), (200.045, 203.305), (71.537, 83.347), (186.035, 191.6), (0.32, 3.755), (288.428, 291.808), (104.604, 116.529), (69.919, 74.559), (161.507, 163.797), (281.816, 284.891), (38.728, 41.183), (203.426, 216.331), (66.846, 70.546), (120.617, 123.707), (85.125, 89.15)}
Shape of mels: torch.Size([40, 18442])
mel_length_time: 1.28
num_data: 922
X.shape is torch.Size([922, 40, 40])
y.shape is torch.Size([922, 1])
Metadata: AudioMetaData(sample_rate=16000, num_frames=4353202, num_channels=1, bits_per_sample=32, encoding=PCM_F)
{(14.381, 24.836), (218.954, 221.939), (176.612, 183.372), (78.457, 82.312), (143.88, 148.945), (268.635, 272.075), (96.389, 103.999), (251.115, 258.56), (132.073, 135.093), (246.278, 253.308), (100.741, 104.906),