In [2]:
import sys

sys.path.insert(0, '..')

In [3]:
import json

import torch
from torch.utils.data import DataLoader, TensorDataset
import torchaudio
from torchaudio.transforms import MelSpectrogram

from params import sample_rate, windowed_signal_length, num_mel_bands, overlap, hop_length

In [4]:
class MelSpecPipeline(torch.nn.Module):
    def __init__(self, n_fft=windowed_signal_length, sample_rate=sample_rate, n_mel=num_mel_bands):
        super().__init__()
        self.mel_spec = MelSpectrogram(sample_rate=sample_rate, n_fft=n_fft, n_mels=n_mel, power=2, center=False, hop_length=hop_length)

    def forward(self, wave):
        assert wave.shape[0] == 1

        mel_spec = self.mel_spec(wave)
        return mel_spec
    
pipeline = MelSpecPipeline()

In [24]:
def check_audio_metadata(metadata):
    assert metadata.sample_rate == 16000
    assert metadata.num_channels == 1
    assert metadata.num_frames > 0

def speechOverlap(mel_time_start, mel_time_end, speech_segments):
    for speech_start, speech_end in speech_segments:
        if speech_start < mel_time_end and mel_time_start < speech_end:
            return True
    return False

def createDataFromRecording(session_root, id):
    wav_path = session_root + "/session_" + str(id) + "_mixture.wav"
    json_path = session_root + "/session_" + str(id) + ".json"

    # check some metadata
    metadata = torchaudio.info(wav_path)
    check_audio_metadata(metadata)
    # print(f'Metadata: {metadata}')

    # retrieve speech segments
    speech_segments = set()
    with open(json_path, 'r') as f:
        speech_info = json.load(f)
    for key in speech_info:
        if key.isdigit():
            for info in speech_info[key]:
                segment = (info["start"], info["stop"])
                assert segment[0] < segment[1]
                speech_segments.add(segment)
    print(speech_segments)

    covered = sum([end - start for start, end in speech_segments])
    print(f'covered length is {covered}, percentage of recording: {covered / 500}')

    # MFSC pipeline
    wave, _ = torchaudio.load(wav_path)
    mels = pipeline(wave)
    mels.squeeze_(0)
    # librosa.display.specshow(mels.numpy())
    print(f'Shape of mels: {mels.shape}')
    num_data = mels.shape[1] // (num_mel_bands // overlap)

    mel_length_time = ((num_mel_bands - 1) * hop_length + windowed_signal_length) / sample_rate
    print(f'mel_length_time: {mel_length_time}')

    X = []
    y = []

    for i in range(num_data - 1):
        mel_slice_start = i * (num_mel_bands // overlap)
        mel_slice_end = mel_slice_start + num_mel_bands

        mel_time_start = mel_slice_start * mel_length_time
        mel_time_end = mel_time_start + mel_length_time

        X.append(mels[:, mel_slice_start : mel_slice_end].clone().detach())
        y.append(torch.ones(1) if speechOverlap(mel_time_start, mel_time_end, speech_segments) else torch.zeros(1))

        if i == num_data - 2:
            # print the 4 mel values above
            print(f'mel_slice_start: {mel_slice_start}, mel_slice_end: {mel_slice_end}, mel_time_start: {mel_time_start}, mel_time_end: {mel_time_end}')

    X = torch.cat(X, dim=0)
    y = torch.cat(y, dim=0)
    return X, y

X, y = createDataFromRecording(session_root="LibriParty/dataset/train/session_0", id=0)
print(f'ratio of speech to non-speech labels: {((y==1).sum() / (y==0).sum() * 100):.2f}%')

{(124.63, 136.975), (72.876, 82.036), (194.072, 205.542), (158.155, 162.49), (0.582, 16.477), (233.761, 246.866), (14.198, 25.438), (262.443, 270.543), (90.981, 106.051), (273.033, 288.833), (52.268, 68.272), (137.211, 151.031), (250.985, 265.71), (273.62, 287.805), (123.585, 134.32), (98.198, 112.692), (207.596, 223.196), (208.436, 220.976), (234.433, 249.023)}
covered length is 243.21300000000002, percentage of recording: 0.486426
Shape of mels: torch.Size([40, 18685])
mel_length_time: 0.656
mel_slice_start: 18640, mel_slice_end: 18680, mel_time_start: 12227.84, mel_time_end: 12228.496000000001
ratio of speech to non-speech labels: 1.74%


In [None]:
# def createDataset(root, num_sessions):
#     X = []
#     y = []
#     for session in range(num_sessions):
#         session_X, session_y = createDataFromRecording(session_root=root + "session_" + str(session), id=session)
#         assert torch.is_tensor(session_X) and torch.is_tensor(session_y) 
#         assert session_X.shape[1] == num_mel_bands and session_X.shape[1] == num_mel_bands and session_y.shape[1] == 1 and session_X.shape[0] == session_y.shape[0]

#         X.append(session_X)
#         y.append(session_y)
#     X = torch.cat(X, dim=0)
#     y = torch.cat(y, dim=0)

#     print(f'X.shape is: {X.shape}')
#     print(f'y.shape is: {y.shape}')
#     return X, y

# X, y = createDataset(root='LibriParty/dataset/dev/', num_sessions=50)

{(80.508, 86.318), (292.097, 295.067), (107.296, 112.546), (6.534, 8.979), (144.639, 149.959), (22.107, 33.907), (204.186, 211.701), (11.362, 15.032), (200.045, 203.305), (71.537, 83.347), (186.035, 191.6), (0.32, 3.755), (288.428, 291.808), (104.604, 116.529), (69.919, 74.559), (161.507, 163.797), (281.816, 284.891), (38.728, 41.183), (203.426, 216.331), (66.846, 70.546), (120.617, 123.707), (85.125, 89.15)}
Shape of mels: torch.Size([40, 18440])
mel_length_time: 0.656


  X.append(torch.tensor(mels[:, mel_slice_start : mel_slice_end]))


IndexError: tuple index out of range

In [None]:
print(f'num 0s: {(y==0).sum()}')
print(f'num 1s: {(y==1).sum()}')

num 0s: 44935
num 1s: 551
