In [13]:
import sys

sys.path.insert(0, '..')

In [14]:
import json
import os
from json import JSONDecodeError

import numpy as np

import torch
from torch.utils.data import DataLoader, TensorDataset
import torchaudio
from torchaudio.transforms import MelSpectrogram

from params import sample_rate, windowed_signal_length, num_mel_bands, overlap, hop_length

In [15]:
class MelSpecPipeline(torch.nn.Module):
    """Convert a single-channel waveform into a mel power spectrogram.

    Framing comes from ``params``: frames are cut without centering/padding
    (``center=False``) and advance by ``hop_length`` samples; ``power=2`` yields a
    power (squared-magnitude) spectrogram rather than a magnitude one.

    Args:
        n_fft: FFT / analysis window size in samples.
        sample_rate: sampling rate handed to the mel filterbank.
        n_mel: number of mel bands (rows of the output).
    """

    def __init__(self, n_fft=windowed_signal_length, sample_rate=sample_rate, n_mel=num_mel_bands):
        super().__init__()
        self.mel_spec = MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            n_mels=n_mel,
            power=2,
            center=False,
            hop_length=hop_length,
        )

    def forward(self, wave):
        """Return the mel spectrogram of ``wave``.

        The leading axis of ``wave`` must have size 1 (the channel axis of a
        ``torchaudio.load`` result, i.e. mono audio).
        """
        assert wave.shape[0] == 1
        spectrogram = self.mel_spec(wave)
        return spectrogram

# Shared instance used when building the datasets (see createDataFromRecording).
pipeline = MelSpecPipeline()

In [16]:
def check_audio_metadata(metadata, expected_sample_rate=16000, expected_num_channels=1):
    """Validate that a recording matches what the feature pipeline expects.

    Args:
        metadata: object exposing ``sample_rate``, ``num_channels`` and ``num_frames``
            (e.g. the result of ``torchaudio.info``).
        expected_sample_rate: required sampling rate in Hz (default 16000).
        expected_num_channels: required channel count (default 1, i.e. mono).

    Raises:
        AssertionError: if a check fails; the message names the offending value.
    """
    assert metadata.sample_rate == expected_sample_rate, (
        f"expected sample rate {expected_sample_rate} Hz, got {metadata.sample_rate}"
    )
    assert metadata.num_channels == expected_num_channels, (
        f"expected {expected_num_channels} channel(s), got {metadata.num_channels}"
    )
    assert metadata.num_frames > 0, f"recording is empty (num_frames={metadata.num_frames})"

def speechOverlap(mel_time_start, mel_time_end, speech_segments):
    """Tell whether the interval ``[mel_time_start, mel_time_end)`` intersects any segment.

    Intervals are treated as half-open, so segments that merely touch an endpoint
    do not count as overlapping.

    Args:
        mel_time_start: start of the query interval.
        mel_time_end: end of the query interval.
        speech_segments: iterable of ``(start, end)`` pairs in the same time unit.

    Returns:
        True if at least one segment overlaps the query interval, otherwise False.
    """
    return any(
        seg_start < mel_time_end and mel_time_start < seg_end
        for seg_start, seg_end in speech_segments
    )

def _load_speech_segments(json_path):
    """Parse a session JSON file and return its speech intervals.

    Only top-level keys made up solely of digits are read; each is expected to map to a
    list of dicts carrying ``"start"`` and ``"stop"`` entries. Returning a set collapses
    duplicate intervals.

    Args:
        json_path: path to the session's ``.json`` annotation file.

    Returns:
        set of ``(start, stop)`` tuples with ``start < stop``.

    Raises:
        RuntimeError: if the file is not valid UTF-8 or not valid JSON; the message
            includes the path, the byte size and a short preview of the content.
        AssertionError: if an interval has ``start >= stop``.
    """
    with open(json_path, 'rb') as f:
        raw = f.read()
    try:
        speech_info = json.loads(raw.decode('utf-8'))
    except (UnicodeDecodeError, JSONDecodeError) as e:
        # Name the failing file: across hundreds of sessions a bare decode error is hard to trace.
        print(f'Error decoding JSON at file: {json_path}')
        preview = repr(raw[:400])
        raise RuntimeError(
            f"Failed to parse JSON file {json_path!r} (size={len(raw)} bytes). preview={preview}"
        ) from e

    segments = set()
    for key in speech_info:
        if not key.isdigit():
            continue
        for info in speech_info[key]:
            segment = (info["start"], info["stop"])
            assert segment[0] < segment[1], f"invalid speech segment {segment} in {json_path}"
            segments.add(segment)
    return segments


def _check_speech_ratio(y, duration, num_windows, speech_segments, tolerance=0.05):
    """Assert that the labelled speech ratio roughly agrees with a coarse reference estimate.

    The reference splits ``[0, duration]`` into ``num_windows`` equal, non-overlapping
    intervals and counts those that overlap speech. The real windows may overlap and are
    not aligned to that grid, so the two ratios only need to agree within ``tolerance``.

    Raises:
        AssertionError: if ``|reference - labelled| > tolerance``.
    """
    if num_windows <= 0:
        return
    # linspace gives exactly num_windows intervals spanning the whole recording; a
    # float-step arange can gain or lose an endpoint through rounding.
    edges = np.linspace(0.0, duration, num_windows + 1)
    reference = [speechOverlap(a, b, speech_segments) for a, b in zip(edges[:-1], edges[1:])]
    speech_ratio_theory = sum(reference) / len(reference)
    speech_ratio_data = (y == 1).float().mean().item()
    if abs(speech_ratio_theory - speech_ratio_data) > tolerance:
        print(f'theoretical ratio of speech lables: {speech_ratio_theory}')
        print(f"This data's ratio of speech labels: {speech_ratio_data}")
        # Explicit raise instead of `assert False`, which disappears under `python -O`.
        raise AssertionError(
            f"speech ratio mismatch: reference={speech_ratio_theory:.3f}, "
            f"labelled={speech_ratio_data:.3f}, tolerance={tolerance}"
        )


def createDataFromRecording(session_root, id):
    """Cut one session into square mel-spectrogram windows labelled speech / non-speech.

    Reads ``session_<id>_mixture.wav`` and ``session_<id>.json`` from ``session_root``,
    computes the mel spectrogram with the module-level ``pipeline`` and slides a window of
    ``num_mel_bands`` frames over it with a stride of ``num_mel_bands // overlap`` frames.
    A window is labelled 1 when its time span overlaps any annotated speech interval, else 0.

    Args:
        session_root: directory holding the session's wav and json files.
        id: session number used in the file names.

    Returns:
        (X, y): X has shape ``(num_windows, num_mel_bands, num_mel_bands)``; y is float32
        of shape ``(num_windows, 1)`` holding 0/1. Both have zero rows when the recording
        is shorter than a single window.

    Raises:
        AssertionError: on unexpected audio metadata, malformed segments, or when the
            labelled speech ratio strays more than 0.05 from a coarse reference estimate.
        RuntimeError: if the session JSON cannot be decoded.
        ValueError: if ``overlap`` leaves a zero frame stride.
    """
    wav_path = os.path.join(session_root, f"session_{id}_mixture.wav")
    json_path = os.path.join(session_root, f"session_{id}.json")

    metadata = torchaudio.info(wav_path)
    check_audio_metadata(metadata)
    # Exact duration in seconds (integer division would drop the trailing fraction).
    duration = metadata.num_frames / metadata.sample_rate

    speech_segments = _load_speech_segments(json_path)

    wave, _ = torchaudio.load(wav_path)
    mels = pipeline(wave).squeeze(0).detach()  # -> (n_mels, n_frames)

    window = num_mel_bands
    stride = num_mel_bands // overlap
    if stride <= 0:
        raise ValueError(
            f"overlap={overlap} leaves a zero frame stride for num_mel_bands={num_mel_bands}"
        )
    n_frames = mels.shape[1]
    # Number of windows that fit entirely inside the spectrogram. This equals the former
    # `n_frames // stride - 1` when window == 2 * stride, and stays in bounds otherwise.
    num_windows = (n_frames - window) // stride + 1 if n_frames >= window else 0

    if num_windows == 0:
        return mels.new_zeros((0, mels.shape[0], window)), torch.zeros((0, 1))

    # unfold -> (n_mels, num_windows, window); reorder to (num_windows, n_mels, window)
    # and materialize it as a contiguous tensor.
    X = mels.unfold(1, window, stride).permute(1, 0, 2).contiguous()
    assert X.shape[1:] == (num_mel_bands, num_mel_bands)

    # With center=False, mel frame i starts at sample i * hop_length, so consecutive
    # frames are hop_length samples apart.
    frame_period = hop_length / metadata.sample_rate
    # NOTE(review): a window's last frame really extends (windowed_signal_length - hop_length)
    # samples past `end_frame * frame_period`; the end time is kept as before so labels
    # stay comparable with previously generated data.
    labels = []
    for i in range(num_windows):
        start_frame = i * stride
        end_frame = start_frame + window
        labels.append(
            speechOverlap(start_frame * frame_period, end_frame * frame_period, speech_segments)
        )
    y = torch.tensor(labels, dtype=torch.float32).unsqueeze(1)

    # Sanity check that the labelled amount of speech is roughly right.
    _check_speech_ratio(y, duration, num_windows, speech_segments)
    return X, y

# Build the windows for a single training session (session 0), e.g. as a quick end-to-end check.
X, y = createDataFromRecording(
    session_root="LibriParty/dataset/train/session_0",
    id=0,
)

In [17]:
def createDataset(root, num_sessions):
    """Build a TensorDataset from sessions ``0 .. num_sessions - 1`` stored under ``root``.

    Each session is read from ``<root>/session_<i>`` via ``createDataFromRecording`` and
    the per-session windows are concatenated.

    Args:
        root: directory containing the ``session_<i>`` sub-directories (with or without
            a trailing slash).
        num_sessions: number of sessions to load; 0 yields an empty dataset.

    Returns:
        TensorDataset ``(X, y)`` with X float32 of shape ``(N, 1, num_mel_bands, num_mel_bands)``
        and y float32 of shape ``(N,)``.

    Raises:
        AssertionError: if a session returns tensors with unexpected shapes.
    """
    session_Xs = []
    session_ys = []
    for session in range(num_sessions):
        session_X, session_y = createDataFromRecording(
            session_root=os.path.join(root, f"session_{session}"), id=session
        )
        assert torch.is_tensor(session_X) and torch.is_tensor(session_y)
        # Per-session contract: X is (n, num_mel_bands, num_mel_bands) and y is (n, 1).
        assert session_X.shape[1:] == (num_mel_bands, num_mel_bands), (
            f"session {session}: unexpected X shape {tuple(session_X.shape)}"
        )
        assert session_y.shape[1:] == (1,) and session_X.shape[0] == session_y.shape[0], (
            f"session {session}: X {tuple(session_X.shape)} / y {tuple(session_y.shape)} mismatch"
        )
        session_Xs.append(session_X)
        session_ys.append(session_y)

    if session_Xs:
        X = torch.cat(session_Xs, dim=0)
        y = torch.cat(session_ys, dim=0)
    else:
        # torch.cat rejects an empty list; return correctly shaped empty tensors instead.
        X = torch.empty((0, num_mel_bands, num_mel_bands))
        y = torch.empty((0, 1))

    # Add a singleton channel axis -> (N, 1, H, W). Drop only y's trailing axis so that a
    # single-example dataset keeps y one-dimensional (a bare squeeze would make it 0-D).
    X = X.unsqueeze(1).to(torch.float32)
    y = y.squeeze(1).to(torch.float32)

    print(f'X.shape is: {X.shape}')
    print(f'y.shape is: {y.shape}')

    assert X.shape[1] == 1 and X.shape[2] == num_mel_bands and X.shape[3] == num_mel_bands
    assert y.dim() == 1

    return TensorDataset(X, y)

In [18]:
# Validation split: sessions 0-49 of LibriParty's `dev` directory.
valid_ds = createDataset(
    root='LibriParty/dataset/dev/',
    num_sessions=50,
)

X.shape is: torch.Size([45461, 1, 40, 40])
y.shape is: torch.Size([45461])


In [19]:
# Test split: sessions 0-49 of LibriParty's `eval` directory.
test_ds = createDataset(
    root='LibriParty/dataset/eval/',
    num_sessions=50,
)

X.shape is: torch.Size([45068, 1, 40, 40])
y.shape is: torch.Size([45068])


In [21]:
# Training split: sessions 0-249 of LibriParty's `train` directory.
train_ds = createDataset(
    root='LibriParty/dataset/train/',
    num_sessions=250,
)

X.shape is: torch.Size([226699, 1, 40, 40])
y.shape is: torch.Size([226699])


In [22]:
# Persist each split as a pickled TensorDataset in the current working directory.
# NOTE(review): since PyTorch 2.6, torch.load defaults to weights_only=True, which refuses
# arbitrary pickled objects such as TensorDataset. Loaders must pass weights_only=False
# (only for trusted files), or these splits could be saved as plain tensors instead.
torch.save(train_ds, "train_ds.pt")
torch.save(valid_ds, "valid_ds.pt")
torch.save(test_ds, "test_ds.pt")