In [2]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import eigh, LinAlgError
#Audio stuff
#from signal_processing import audiowrite, stft, istft
from IPython.display import Audio
import wave
import librosa
from scipy.io import wavfile
import scipy

In [3]:
import torchaudio.transforms
import torch
import torchaudio.functional as F

In [4]:
#Evaluation stuff
from pesq import pesq
from pystoi import stoi
import mir_eval

### Signal Processing

In [5]:
from numpy.fft import rfft, irfft
from scipy import signal
from scipy.io.wavfile import write as wav_write
import string
import threading
def _samples_to_stft_frames(samples, size, shift):
    """
    Calculates STFT frames from samples in time domain.
    :param samples: Number of samples in time domain.
    :param size: FFT size.
    :param shift: Hop in samples.
    :return: Number of STFT frames.
    """

    return np.ceil((samples - size + shift) / shift).astype(np.int64)


def segment_axis(a, length, overlap=0, axis=None, end='cut', endvalue=0):
    """Generate a new array that chops the given array along the given axis into overlapping frames.

    example:
    >>> segment_axis(np.arange(10), 4, 2)
    array([[0, 1, 2, 3],
           [2, 3, 4, 5],
           [4, 5, 6, 7],
           [6, 7, 8, 9]])

    arguments:
    a       The array to segment
    length  The length of each frame
    overlap The number of array elements by which the frames should overlap
    axis    The axis to operate on; if None, act on the flattened array
    end     What to do with the last frame, if the array is not evenly
            divisible into pieces. Options are:

            'cut'   Simply discard the extra values
            'wrap'  Copy values from the beginning of the array
            'pad'   Pad with a constant value

    endvalue    The value to use for end='pad'

    The array is not copied unless necessary (either because it is
    unevenly strided and being flattened or because end is set to
    'pad' or 'wrap').
    """

    if axis is None:
        a = np.ravel(a)  # may copy
        axis = 0

    l = a.shape[axis]

    if overlap >= length: raise ValueError(
            "frames cannot overlap by more than 100%")
    if overlap < 0 or length <= 0: raise ValueError(
            "overlap must be nonnegative and length must be positive")

    if l < length or (l - length) % (length - overlap):
        if l > length:
            roundup = length + (1 + (l - length) // (length - overlap)) * (
                length - overlap)
            rounddown = length + ((l - length) // (length - overlap)) * (
                length - overlap)
        else:
            roundup = length
            rounddown = 0
        assert rounddown < l < roundup
        assert roundup == rounddown + (length - overlap) or (
            roundup == length and rounddown == 0)
        a = a.swapaxes(-1, axis)

        if end == 'cut':
            a = a[..., :rounddown]
        elif end in ['pad', 'wrap']:  # copying will be necessary
            s = list(a.shape)
            s[-1] = roundup
            b = np.empty(s, dtype=a.dtype)
            b[..., :l] = a
            if end == 'pad':
                b[..., l:] = endvalue
            elif end == 'wrap':
                b[..., l:] = a[..., :roundup - l]
            a = b

        a = a.swapaxes(-1, axis)

    l = a.shape[axis]
    if l == 0: raise ValueError(
            "Not enough data points to segment array in 'cut' mode; "
            "try 'pad' or 'wrap'")
    assert l >= length
    assert (l - length) % (length - overlap) == 0
    n = 1 + (l - length) // (length - overlap)
    s = a.strides[axis]
    newshape = a.shape[:axis] + (n, length) + a.shape[axis + 1:]
    newstrides = a.strides[:axis] + ((length - overlap) * s, s) + a.strides[
                                                                  axis + 1:]

    if not a.flags.contiguous:
        a = a.copy()
        newstrides = a.strides[:axis] + ((length - overlap) * s, s) + a.strides[
                                                                      axis + 1:]
        return np.ndarray.__new__(np.ndarray, strides=newstrides,
                                  shape=newshape, buffer=a, dtype=a.dtype)

    try:
        return np.ndarray.__new__(np.ndarray, strides=newstrides,
                                  shape=newshape, buffer=a, dtype=a.dtype)
    except TypeError or ValueError:
        warnings.warn("Problem with ndarray creation forces copy.")
        a = a.copy()
        # Shape doesn't change but strides does
        newstrides = a.strides[:axis] + ((length - overlap) * s, s) + a.strides[
                                                                      axis + 1:]
        return np.ndarray.__new__(np.ndarray, strides=newstrides,
                                  shape=newshape, buffer=a, dtype=a.dtype)


def _stft_frames_to_samples(frames, size, shift):
    """
    Calculates samples in time domain from STFT frames
    :param frames: Number of STFT frames.
    :param size: FFT size.
    :param shift: Hop in samples.
    :return: Number of samples in time domain.
    """
    return frames * shift + size - shift


def stft(time_signal, time_dim=None, size=512, shift=256,
         window=signal.blackman, fading=True, window_length=None):
    """
    Calculates the short time Fourier transformation of a multi channel multi
    speaker time signal. It is able to add additional zeros for fade-in and
    fade out and should yield an STFT signal which allows perfect
    reconstruction.

    :param time_signal: multi channel time signal.
    :param time_dim: Scalar dim of time.
        Default: None means the biggest dimension
    :param size: Scalar FFT-size.
    :param shift: Scalar FFT-shift. Typically shift is a fraction of size.
    :param window: Window function handle.
    :param fading: Pads the signal with zeros for better reconstruction.
    :param window_length: Sometimes one desires to use a shorter window than
        the fft size. In that case, the window is padded with zeros.
        The default is to use the fft-size as a window size.
    :return: Single channel complex STFT signal
        with dimensions frames times size/2+1.
    """
    if time_dim is None:
        time_dim = np.argmax(time_signal.shape)

    # Pad with zeros to have enough samples for the window function to fade.
    if fading:
        pad = [(0, 0)] * time_signal.ndim
        pad[time_dim] = [size - shift, size - shift]
        time_signal = np.pad(time_signal, pad, mode='constant')

    # Pad with trailing zeros, to have an integral number of frames.
    frames = _samples_to_stft_frames(time_signal.shape[time_dim], size, shift)
    samples = _stft_frames_to_samples(frames, size, shift)
    pad = [(0, 0)] * time_signal.ndim
    pad[time_dim] = [0, samples - time_signal.shape[time_dim]]
    time_signal = np.pad(time_signal, pad, mode='constant')

    if window_length is None:
        window = window(size)
    else:
        window = window(window_length)
        window = np.pad(window, (0, size - window_length), mode='constant')

    time_signal_seg = segment_axis(time_signal, size,
                                   size - shift, axis=time_dim)

    letters = string.ascii_lowercase
    mapping = letters[:time_signal_seg.ndim] + ',' + letters[time_dim + 1] \
              + '->' + letters[:time_signal_seg.ndim]

    return rfft(np.einsum(mapping, time_signal_seg, window),
                axis=time_dim + 1)

def _biorthogonal_window_loopy(analysis_window, shift):
    """
    This version of the synthesis calculation is as close as possible to the
    Matlab impelementation in terms of variable names.

    The results are equal.

    The implementation follows equation A.92 in
    Krueger, A. Modellbasierte Merkmalsverbesserung zur robusten automatischen
    Spracherkennung in Gegenwart von Nachhall und Hintergrundstoerungen
    Paderborn, Universitaet Paderborn, Diss., 2011, 2011
    """
    fft_size = len(analysis_window)
    assert np.mod(fft_size, shift) == 0
    number_of_shifts = len(analysis_window) // shift

    sum_of_squares = np.zeros(shift)
    for synthesis_index in range(0, shift):
        for sample_index in range(0, number_of_shifts + 1):
            analysis_index = synthesis_index + sample_index * shift

            if analysis_index + 1 < fft_size:
                sum_of_squares[synthesis_index] \
                    += analysis_window[analysis_index] ** 2

    sum_of_squares = np.kron(np.ones(number_of_shifts), sum_of_squares)
    synthesis_window = analysis_window / sum_of_squares / fft_size
    return synthesis_window

def istft(stft_signal, size=512, shift=256,
          window=signal.blackman, fading=True, window_length=None):
    """
    Calculated the inverse short time Fourier transform to exactly reconstruct
    the time signal.

    :param stft_signal: Single channel complex STFT signal
        with dimensions frames times size/2+1.
    :param size: Scalar FFT-size.
    :param shift: Scalar FFT-shift. Typically shift is a fraction of size.
    :param window: Window function handle.
    :param fading: Removes the additional padding, if done during STFT.
    :param window_length: Sometimes one desires to use a shorter window than
        the fft size. In that case, the window is padded with zeros.
        The default is to use the fft-size as a window size.
    :return: Single channel complex STFT signal
    :return: Single channel time signal.
    """
    print(stft_signal.shape[1])
    assert stft_signal.shape[1] == size // 2 + 1

    if window_length is None:
        window = window(size)
    else:
        window = window(window_length)
        window = np.pad(window, (0, size - window_length), mode='constant')

    window = _biorthogonal_window_loopy(window, shift)

    # Why? Line created by Hai, Lukas does not know, why it exists.
    window *= size

    time_signal = scipy.zeros(stft_signal.shape[0] * shift + size - shift)

    for j, i in enumerate(range(0, len(time_signal) - size + shift, shift)):
        time_signal[i:i + size] += window * np.real(irfft(stft_signal[j]))

    # Compensate fade-in and fade-out
    if fading:
        time_signal = time_signal[
                      size - shift:len(time_signal) - (size - shift)]

    return time_signal

### Class Beamformer

In [6]:
class GEVBeamformer:
    def __init__(self, gamma=1e-6):
        self.gamma = gamma
        
    def condition_covariance(self, x, gamma):
        
        """see https://stt.msu.edu/users/mauryaas/Ashwini_JPEN.pdf (2.3)"""
        
        scale = gamma * np.trace(x) / x.shape[-1]
        scaled_eye = np.eye(x.shape[-1]) * scale
        return (x + scaled_eye) / (1 + gamma)
    
    def phase_correction(self, vector):
        """Phase correction to reduce distortions due to phase inconsistencies
        Args:
        vector: Beamforming vector with shape (..., bins, sensors).
        Returns: Phase corrected beamforming vectors. Lengths remain.
        """
        
        w = vector.copy()
        F, D = w.shape
        for f in range(1, F):
            w[f, :] *= np.exp(-1j*np.angle(
                np.sum(w[f, :] * w[f-1, :].conj(), axis=-1, keepdims=True)))
        return w
    
    def get_gev_vector(self, target_psd_matrix, noise_psd_matrix):
        
        """
        Returns the GEV beamforming vector.
        :param target_psd_matrix: Target PSD matrix
        with shape (bins, sensors, sensors)
        :param noise_psd_matrix: Noise PSD matrix
        with shape (bins, sensors, sensors)
        :return: Set of beamforming vectors with shape (bins, sensors)
        """
        #target_psd_matrix = target_psd_matrix[0,:,:,:]
        target_psd_matrix = target_psd_matrix[0,0]
        print(target_psd_matrix.shape)
        #noise_psd_matrix = noise_psd_matrix[0,:,:,:]
        noise_psd_matrix = noise_psd_matrix[0,0]
        bins, sensors, _ = target_psd_matrix.shape
        beamforming_vector = np.empty((bins, sensors), dtype=np.complex128)
        for f in range(bins):
            try:
                eigenvals, eigenvecs = eigh(target_psd_matrix[f, :, :],
                                        noise_psd_matrix[f, :, :])
                beamforming_vector[f, :] = eigenvecs[:, -1]
            except np.linalg.LinAlgError:
                print('LinAlg error for frequency {}'.format(f))
                beamforming_vector[f, :] = (
                    np.ones((sensors,)) / np.trace(noise_psd_matrix[f]) * sensors)
        return beamforming_vector


    def get_beamforming_vector(self, target_psd_matrix, noise_psd_matrix):
        # First, condition the noise PSD matrix
        conditioned_noise_psd = self.condition_covariance(noise_psd_matrix)
        # Then, get the GEV beamforming vector
        return get_gev_vector(target_psd_matrix, conditioned_noise_psd)
    
    def apply_beamforming_vector(self, vector, mix):
        return np.einsum('...a,...at->...t', vector.conj(), mix)
    
    def gev_wrapper_on_masks(self, mix, target_psd_matrix, noise_psd_matrix,
                         normalization=False):
        org_dtype = mix.dtype
        mix = mix.astype(np.complex128)
        mix = mix.T
        noise_psd_matrix = self.condition_covariance(noise_psd_matrix, 1e-6)
        noise_psd_matrix /= np.trace(
            noise_psd_matrix, axis1=-2, axis2=-1)[..., None, None]
        W_gev = self.get_gev_vector(target_psd_matrix, noise_psd_matrix)
        print (f'Shape of GEV vector: {W_gev.shape}')
        W_gev = self.phase_correction(W_gev)
        print (f'Shape of GEV vector: {W_gev.shape}')
        if normalization:
            W_gev = blind_analytic_normalization(W_gev, noise_psd_matrix)
        print (f'Shape of GEV vector: {W_gev.shape}')
        output = self.apply_beamforming_vector(W_gev, mix)
        print (f'Shape of GEV vector: {output.shape}')
        #output = output.astype(org_dtype)
        print (f'Shape of GEV vector: {output.shape}')
        return output.T

In [7]:
#Loading audio data
target, sr = librosa.load("../test/id08696_f0gzvkyroAY_00388VSid04030_JbcD0P6KGe0_00039/00039_mic1.wav", sr=16000)  
noise, sr = librosa.load("../test/id08696_f0gzvkyroAY_00388VSid04030_JbcD0P6KGe0_00039/00388_mic1.wav", sr=16000)

In [8]:
##Trying out Dorothea suggestion
print(target.shape)
print(noise.shape)
A = np.array([[1, 1], [1, -1]])
signals = np.stack([target, noise])
# Apply the mixing matrix
mixed_signals = A @ signals

(64512,)
(64512,)


In [9]:
print(mixed_signals.shape)
display(Audio(mixed_signals[1], rate=16000))  # Replace 16000 with your sample rate

(2, 64512)


In [10]:
size = 512
size // 2 + 1

257

In [11]:
Y = stft(mixed_signals, time_dim=1).transpose((1, 0, 2))

  window = window(size)


In [12]:
X=istft(Y[:,1,:],512)


257


  window = window(size)
  time_signal = scipy.zeros(stft_signal.shape[0] * shift + size - shift)


In [13]:
display(Audio(X, rate=16000))

In [14]:
from chainer import Variable
Y_var = Variable(np.abs(Y).astype(np.float32))

In [29]:
def get_power_spectral_density_matrix(observation, mask=None, normalize=True):
    """
    Calculates the weighted power spectral density matrix.

    This does not yet work with more than one target mask.

    :param observation: Complex observations with shape (bins, sensors, frames)
    :param mask: Masks with shape (bins, frames) or (bins, 1, frames)
    :return: PSD matrix with shape (bins, sensors, sensors)
    """
    bins, sensors, frames = observation.shape

    if mask is None:
        mask = np.ones((bins, frames))
    if mask.ndim == 2:
        mask = mask[:, np.newaxis, :]

    psd = np.einsum('...dt,...et->...de', mask * observation,
                    observation.conj())
    if normalize:
        normalization = np.sum(mask, axis=-1, keepdims=True)
        psd /= normalization
    return psd


def condition_covariance(x, gamma):
    """see https://stt.msu.edu/users/mauryaas/Ashwini_JPEN.pdf (2.3)"""
    scale = gamma * np.trace(x) / x.shape[-1]
    scaled_eye = np.eye(x.shape[-1]) * scale
    return (x + scaled_eye) / (1 + gamma)

def get_gev_vector(target_psd_matrix, noise_psd_matrix):
    """
    Returns the GEV beamforming vector.
    :param target_psd_matrix: Target PSD matrix
        with shape (bins, sensors, sensors)
    :param noise_psd_matrix: Noise PSD matrix
        with shape (bins, sensors, sensors)
    :return: Set of beamforming vectors with shape (bins, sensors)
    """
    bins, sensors, _ = target_psd_matrix.shape
    beamforming_vector = np.empty((bins, sensors), dtype=np.complex128)
    for f in range(bins):
        try:
            eigenvals, eigenvecs = eigh(target_psd_matrix[f, :, :],
                                        noise_psd_matrix[f, :, :])
            beamforming_vector[f, :] = eigenvecs[:, -1]
        except np.linalg.LinAlgError:
            print('LinAlg error for frequency {}'.format(f))
            beamforming_vector[f, :] = (
                np.ones((sensors,)) / np.trace(noise_psd_matrix[f]) * sensors
            )
    return beamforming_vector

def gev_wrapper_on_masks(mix, noise_mask=None, target_mask=None,
                         normalization=False):
    #if noise_mask is None and target_mask is None:
     #   raise ValueError('At least one mask needs to be present.')

    org_dtype = mix.dtype
    mix = mix.astype(np.complex128)
    mix = mix.T
    if noise_mask is not None:
        noise_mask = noise_mask.T
    if target_mask is not None:
        target_mask = target_mask.T

    target_psd_matrix = get_power_spectral_density_matrix(
        mix, target_mask, normalize=False)
    noise_psd_matrix = get_power_spectral_density_matrix(
        mix, noise_mask, normalize=True)
    noise_psd_matrix = condition_covariance(noise_psd_matrix, 1e-6)
    noise_psd_matrix /= np.trace(
        noise_psd_matrix, axis1=-2, axis2=-1)[..., None, None]
    W_gev = get_gev_vector(target_psd_matrix, noise_psd_matrix)
    W_gev = phase_correction(W_gev)

    if normalization:
        W_gev = blind_analytic_normalization(W_gev, noise_psd_matrix)

    output = apply_beamforming_vector(W_gev, mix)
    output = output.astype(org_dtype)

    return output.T

def phase_correction(vector):
        """Phase correction to reduce distortions due to phase inconsistencies
        Args:
        vector: Beamforming vector with shape (..., bins, sensors).
        Returns: Phase corrected beamforming vectors. Lengths remain.
        """
        
        w = vector.copy()
        F, D = w.shape
        for f in range(1, F):
            w[f, :] *= np.exp(-1j*np.angle(
                np.sum(w[f, :] * w[f-1, :].conj(), axis=-1, keepdims=True)))
        return w
    
def apply_beamforming_vector(vector, mix):
    return np.einsum('...a,...at->...t', vector.conj(), mix)

In [30]:
Y_hat = gev_wrapper_on_masks(Y)

In [31]:
Y_hat.shape

(253, 257)

In [34]:
X_hat=istft(Y_hat[:,:],512)

257


  window = window(size)
  time_signal = scipy.zeros(stft_signal.shape[0] * shift + size - shift)


In [35]:
display(Audio(X, rate=16000))