In [1]:
import numpy as np
import numpy.matlib
import matplotlib.pyplot as plt
from math import ceil
from scipy.linalg import solve
from beamformers import beamformers as bf
import torch
import soundfile as sf
import IPython.display as ipd

In [2]:
# load some data from the test set
basedir = '/Data/DATASETS/WSJ/data_from_yi/tt/'
sample = 100
datadir = basedir + 'sample{}/'.format(sample)
nmics = 4

fs = 16000

def load_all_traces(sample):
    # spk1
    anec_1 = []
    for i in range(nmics):
        _mics, fs = sf.read(datadir + 'circular_anechoic_spk1_mic{}.wav'.format(i + 1))
        anec_1.append(_mics)
    anec_1 = np.array(anec_1)

    anec_2 = []
    for i in range(nmics):
        _mics, _ = sf.read(datadir + 'circular_anechoic_spk2_mic{}.wav'.format(i + 1))
        anec_2.append(_mics)
    anec_2 = np.array(anec_2)

    echo_1 = []
    for i in range(nmics):
        _mics, _ = sf.read(datadir + 'circular_echoic_spk1_mic{}.wav'.format(i + 1))
        echo_1.append(_mics)
    echo_1 = np.array(echo_1)

    echo_2 = []
    for i in range(nmics):
        _mics, _ = sf.read(datadir + 'circular_echoic_spk2_mic{}.wav'.format(i + 1))
        echo_2.append(_mics)
    echo_2 = np.array(echo_2)

    echo_noise = []
    for i in range(nmics):
        _mics, _ = sf.read(datadir + 'circular_echoic_noise_mic{}.wav'.format(i + 1))
        echo_noise.append(_mics)
    echo_noise = np.array(echo_noise)
    
    # 
    target = echo_1
    noise_sep = echo_2
    noise_enh = echo_noise
    mixture_sep = target + noise_sep
    mixture_enh = target + noise_enh
    
    return echo_1, echo_2, echo_noise, anec_1, anec_2

receptive_field = 0.128  # in s
frame_len = int(fs * receptive_field)
frame_step = int(frame_len / 4)
eps = 1e-15

In [3]:
fs = 16000
sample = 100
max_len = 4 * fs
echo_1, echo_2, echo_noise, anec_1, anec_2 = load_all_traces(sample)
        
target_1 = echo_1.astype('float32')[:, :max_len]
target_2 = echo_2.astype('float32')[:, :max_len]
noise_1_sep = echo_2.astype('float32')[:, :max_len]
noise_2_sep = echo_1.astype('float32')[:, :max_len]
noise_enh = echo_noise.astype('float32')[:, :max_len]

mixture_1_sep = target_1 + noise_1_sep
mixture_2_sep = target_2 + noise_2_sep
mixture_1_enh = target_1 + noise_enh
mixture_2_enh = target_2 + noise_enh

In [22]:
separated = bf.MB_MVDR_oracle(mixture_1_sep, noise_1_sep, target_1, mask="IBM")
observation = bf.stft(mixture_1_sep)
target_1_stft = bf.stft(target_1)
target_2_stft = bf.stft(target_2)
masks = bf.calculate_masks([target_1_stft, target_2_stft], mask="IBM")
mvdr_weights = bf.mb_mvdr_weights(observation, masks[1][0], masks[0][0])

t_mask_speech = torch.from_numpy(masks[0][0]).unsqueeze(0).unsqueeze(2).unsqueeze(2)
t_mask_noise = torch.from_numpy(masks[1][0]).unsqueeze(0).unsqueeze(2).unsqueeze(2)
pre_t_observation = torch.cat([torch.from_numpy(np.real(observation)).unsqueeze(-1), torch.from_numpy(np.imag(observation)).unsqueeze(-1)], -1).unsqueeze(0)

print("np mask1: {}".format(masks[0].shape))
print("np mask2: {}".format(masks[1].shape))
print("np obs: {}".format(mixture_1_sep.shape))
print("np obs: {}".format(mixture_1_sep.dtype))
print("np sep: {}".format(separated.shape))
print("")
print("t mask1: {}".format(t_mask_speech.shape))
print("t mask2: {}".format(t_mask_noise.shape))
print("t obs: {}".format(pre_t_observation.shape))
print("t obs: {}".format(pre_t_observation.dtype))

np mask1: (4, 1025, 98)
np mask2: (4, 1025, 98)
np obs: (4, 49243)
np obs: float32
np sep: (49243,)

t mask1: torch.Size([1, 1025, 1, 1, 98])
t mask2: torch.Size([1, 1025, 1, 1, 98])
t obs: torch.Size([1, 4, 1025, 98, 2])
t obs: torch.float32


In [20]:
def complex_multiply(x, y, conjugate=False):
    # x: (B, M, 2, F, T)
    a = x[:, :, 0].unsqueeze(1)
    b = x[:, :, 1].unsqueeze(1)
    c = y[:, :, 0].unsqueeze(1)
    d = y[:, :, 1].unsqueeze(1) * 1 if not conjugate else -1

    real = a * c - b * d
    imag = a * d + b * c
    return torch.cat([real, imag], 2)

def complex_divide(x, y):
    # x: (B, M, 2, F, T)
    num = complex_multiply(x, y, conjugate=True)  # (B, M, 2, F, T)
    den = complex_multiply(y, y, conjugate=True)[:, :, 0],unsqueeze(2)  # (B, M, 1, F, T)

    return num / den

def trace(x, dim1=-1, dim2=-2, keepdim=True):
    if x.shape[dim1] != x.shape[dim2]:
        raise ValueError("Matrix should be square")
    n = len(x.shape) - 2
    ones = torch.eye(x.shape[dim1])[(None,) * n]  # [..., dim1, dim2]
    if x.is_cuda:
        ones = ones.cuda()
    filt = ones * x
    return filt.sum((dim1, dim2), keepdim=keepdim)

def condition_covariance(x, gamma):
    """see https://stt.msu.edu/users/mauryaas/Ashwini_JPEN.pdf (2.3)"""
    scale = gamma * trace(x, keepdim=True) / x.shape[-1]  # [...]
    n = len(x.shape) - 2
    eye = torch.eye(x.shape[-1])[(None,) * n]
    if x.is_cuda:
        eye = eye.cuda()
    scaled_eye = (eye * scale)
    return (x + scaled_eye) / (1 + gamma)

def einsum(a, b, s='...dt,...et->...de'):
    return torch.einsum(s, a, b)

def complex_psd(x, mask, normalize=True, condition=True, eps=1e-15):
    # x: (B, M, 2, F, T)
    masked = x * mask  # (B, M, 2, F, L)

    a = masked[:, :, 0].unsqueeze(2)  # re x
    b = masked[:, :, 1].unsqueeze(2)  # im x
    c = x[:, :, 0].unsqueeze(2)
    d = - x[:, :, 1].unsqueeze(2)  # im y has been conjugated

    real = einsum(a, c) - einsum(b, d)
    imag = einsum(a, d) + einsum(b, c)
    psd = torch.cat([real, imag], 2)
    
    if normalize:
        normalization = mask.sum(-1).unsqueeze(-1)
        psd /= normalization + eps
        
    return psd

def mse(x, y):
    return np.mean((np.abs(x - y))**2)

In [6]:
t_observation = t_observation.permute(0, 2, 4, 1, 3)  # (B, F, 2, M, L)   prev (B, M, F, L, 2)
back = t_observation.permute(0, 3, 1, 4, 2).squeeze().data.numpy()

print(mse(back[..., 0] + 1j * back[..., 1], observation))

0.0


In [15]:
# numpy psd
cov_noise = bf.get_power_spectral_density_matrix(observation.transpose(1, 0, 2), masks[1][0], normalize=False)
cov_noise = bf.condition_covariance(cov_noise, 1e-6)
cov_noise /= np.real(np.trace(cov_noise, axis1=-2, axis2=-1)[..., None, None])

print(np.real(np.trace(cov_noise, axis1=-2, axis2=-1)))
# torch
psd_noise = complex_psd(t_observation, t_mask_noise, normalize=False)  # (B, F, 2, M, M)
psd_noise = condition_covariance(psd_noise, 1e-6)
psd_noise /= trace(psd_noise, dim1=-1, dim2=-2, keepdim=True)[:, :, 0].unsqueeze(2)
# print(psd_noise[:, :, 1].squeeze().dtype)
print(trace(psd_noise, dim1=-1, dim2=-2, keepdim=False)[0, :, 0])

# error 
back = psd_noise.squeeze().data.numpy()
print(mse(back[:, 0] + 1j * back[:, 1], cov_noise))

[0.99999994 1.         0.9999999  ... 0.99999994 1.         0.99999994]
tensor([1., 1., 1.,  ..., 1., 1., 1.])
1.0331673e-15


In [36]:
n_mic = 4
t_observation = pre_t_observation.permute(0, 2, 4, 1, 3)  # (B, F, 2, M, L)
# calculate psds[(None,) * n]
psd_noise = complex_psd(t_observation, t_mask_noise, normalize=False)  # (B, F, 2, M, M)
psd_speech = complex_psd(t_observation, t_mask_speech, condition=True)  # (B, F, 2, M, M)
psd_noise = condition_covariance(psd_noise, 1e-6)
psd_noise /= trace(psd_noise, dim1=-1, dim2=-2, keepdim=True)[:, :, 0].unsqueeze(2)

# calculate weights
# speech A
# noise B
re_a = psd_speech[:, :, 0]
im_a = psd_speech[:, :, 1]
re_b = psd_noise[:, :, 0]
im_b = psd_noise[:, :, 1]

A = torch.cat([torch.cat([re_a, -im_a], -1), torch.cat([im_a, re_a], -1)], -2)
B = torch.cat([torch.cat([re_b, -im_b], -1), torch.cat([im_b, re_b], -1)], -2)
H, _ = torch.gesv(A, B)  # (B, F, 2, M, M)

trace_H = trace(H, keepdim=True)  # (B, F, 2, 1, 1)
H /= trace_H 

h_re, h_im = H[..., :n_mic, 0], H[..., n_mic:, 0]

# apply weights
a = h_re  # (B, F, M)
b = -h_im
c = t_observation[:, :, 0].permute(0, 2, 1, 3)  # (B, M, F, L)
d = t_observation[:, :, 1].permute(0, 2, 1, 3)  # (B, M, F, L)

real = einsum(a, c, '...ab,...bac->...ac') - einsum(b, d, '...ab,...bac->...ac')
imag = einsum(a, d, '...ab,...bac->...ac') + einsum(b, c, '...ab,...bac->...ac')  # (B, F, L)

filtered = torch.cat([real.unsqueeze(-1), imag.unsqueeze(-1)], -1).squeeze()  # B, F, T, 2

filtered_t = filtered.data.numpy()
filtered_n = filtered_t[:, :, 0] + 1j * filtered_t[:, :, 1]
filtered_n = bf.istft(filtered_n)[: len(separated)]
print("NP separated: {}".format(separated.shape))
print("T separated: {}".format(filtered_n.shape))

print(np.mean(np.abs((filtered_n - separated)**2)))
print(np.std(separated))

NP separated: (49243,)
T separated: (49243,)
1.240427e-08
0.00022263055


In [38]:
ipd.display(ipd.Audio(separated, rate=fs))
ipd.display(ipd.Audio(filtered_n , rate=fs))