In [57]:
import sys

sys.path.insert(0, '..')

In [58]:
from __future__ import annotations
import os
import math
import argparse
from dataclasses import dataclass
from typing import List, Tuple, Dict
import json

import numpy as np
from scipy import signal
from scipy.fft import fft, fftshift
import matplotlib.pyplot as plt
import librosa

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchaudio
from torchaudio.transforms import MelSpectrogram, AmplitudeToDB, TimeMasking, FrequencyMasking

from params import sample_rate, windowed_signal_length, num_mel_bands, overlap

ImportError: cannot import name 'overlap' from 'params' (c:\Users\afkhe\Programming\Projects\Machine Learning\voice-activity-detector\params.py)

In [None]:
class MelSpecPipeline(torch.nn.Module):
    def __init__(self, n_fft=windowed_signal_length, sample_rate=sample_rate, n_mel=num_mel_bands):
        super().__init__()
        self.mel_spec = MelSpectrogram(sample_rate=sample_rate, n_fft=n_fft, n_mels=n_mel, power=2)

    def forward(self, wave):
        assert wave.shape[0] == 1

        mel_spec = self.mel_spec(wave)
        return mel_spec
    
pipeline = MelSpecPipeline()

In [None]:
def check_audio_metadata(metadata):
    assert metadata.sample_rate == 16000
    assert metadata.num_channels == 1
    assert metadata.num_frames > 0

def createDataFromRecording(session_root, id):
    wav_path = session_root + "/session_" + str(id) + "_mixture.wav"
    json_path = session_root + "/session_" + str(id) + ".json"

    # check some metadata
    metadata = torchaudio.info(wav_path)
    check_audio_metadata(metadata)
    print(f'Metadata: {metadata}')

    # retrieve speech segments
    speech_segments = set()
    with open(json_path, 'r') as f:
        speech_info = json.load(f)
    for key in speech_info:
        if key.isdigit():
            for info in speech_info[key]:
                segment = (info["start"], info["stop"])
                assert segment[0] < segment[1]
                speech_segments.add(segment)
    print(speech_segments)

    # MFSC pipeline
    wave, _ = torchaudio.load(wav_path)
    mels = pipeline(wave)
    mels.squeeze_(0)
    # librosa.display.specshow(mels.numpy())
    print(f'Shape of mels: {mels.shape}')

    mel_length_time = (windowed_signal_length * num_mel_bands) / sample_rate
    print(f'mel_length_time: {mel_length_time}')

    num_data = mels.shape[1] // num_mel_bands * overlap
    print(f'num_data: {num_data}')

    X = [torch.ones(num_mel_bands, num_mel_bands)] * num_data
    y = [0.0] * num_data

    # we iterate over all complete mel bands possible. 
    # We drop the end bit of the recording if it isn't long enough for a whole MFSC spectrogram
    # We process with overlaps of 1/`overlap`
    for output_i, mel_i in enumerate(range(0, int(mels.shape[1] - num_mel_bands), num_mel_bands // overlap)):
        mel_start_time = mel_i * mel_length_time
        mel_end_time = (mel_i + 1) * mel_length_time

        for start, end in speech_segments:
            if start < mel_start_time < end or start < mel_end_time < end:
                y[output_i] = 1.0
        X[output_i] = mels[:, mel_i : mel_i + num_mel_bands]
        print(f'X[{output_i}].shape: {X[output_i].shape}')

        # assert X[output_i].shape == (num_mel_bands, num_mel_bands)
    
createDataFromRecording(session_root="LibriParty/dataset/train/session_0", id=0)

Metadata: AudioMetaData(sample_rate=16000, num_frames=4783692, num_channels=1, bits_per_sample=32, encoding=PCM_F)
{(124.63, 136.975), (72.876, 82.036), (194.072, 205.542), (158.155, 162.49), (0.582, 16.477), (233.761, 246.866), (14.198, 25.438), (262.443, 270.543), (90.981, 106.051), (273.033, 288.833), (52.268, 68.272), (137.211, 151.031), (250.985, 265.71), (273.62, 287.805), (123.585, 134.32), (98.198, 112.692), (207.596, 223.196), (208.436, 220.976), (234.433, 249.023)}
Shape of mels: torch.Size([40, 18687])
mel_length_time: 1.28


NameError: name 'overlap' is not defined