In [1]:
import numpy as np
import matplotlib.pyplot as plt
import librosa
from scipy.linalg import eigh, LinAlgError

In [2]:
import torchaudio.transforms
import torch

In [3]:
### Loading data
noise_mask_0 = np.load('../model_results/model_output/noise_mask_mic0.npy')
target_mask_0 = np.load('../model_results/model_output/target_mask_mic0.npy')
noise_mask_1 = np.load('../model_results/model_output/noise_mask_mic1.npy')
target_mask_1 = np.load('../model_results/model_output/target_mask_mic1.npy')

#pred_spec_1_imag = np.load('../model_results/model_output/pred_spec_1_imag.npy')
#pred_spec_1_real = np.load('../model_results/model_output/pred_spec_1_real.npy')
#pred_spec_2_imag = np.load('../model_results/model_output/pred_spec_2_imag.npy')
#pred_spec_2_real = np.load('../model_results/model_output/pred_spec_2_real.npy')

mixture_spec_mic0 = np.load('../model_results/model_output/mixture_spec_mic0.npy')
mixture_spec_mic0 = mixture_spec_mic0[:,:256,:]
mixture_spec_mic1 = np.load('../model_results/model_output/mixture_spec_mic1.npy')
mixture_spec_mic1 = mixture_spec_mic1[:,:256,:]


In [4]:
# Print the shapes of the arrays
print(f"Shape of target_mask_0: {target_mask_0.shape}")
print(f"Shape of noise_mask_0: {noise_mask_0.shape}")
print(f"Shape of target_mask_1: {target_mask_1.shape}")
print(f"Shape of noise_mask_1: {noise_mask_1.shape}")

print(f"Shape of mixture_spec_mic0: {mixture_spec_mic0.shape}")
print(f"Shape of mixture_spec_mic1: {mixture_spec_mic1.shape}")

# If you want to see basic stats about the masks, you can do the following:
#print(f"Target Mask - Min: {np.min(target_mask)}, Max: {np.max(target_mask)}, Mean: {np.mean(target_mask)}, Std: {np.std(target_mask)}")
#print(f"Noise Mask - Min: {np.min(noise_mask)}, Max: {np.max(noise_mask)}, Mean: {np.mean(noise_mask)}, Std: {np.std(noise_mask)}")


Shape of target_mask_0: (1, 2, 256, 256)
Shape of noise_mask_0: (1, 2, 256, 256)
Shape of target_mask_1: (1, 2, 256, 256)
Shape of noise_mask_1: (1, 2, 256, 256)
Shape of mixture_spec_mic0: (2, 256, 256)
Shape of mixture_spec_mic1: (2, 256, 256)


In [5]:
#Prepare data for PSD calculation
# Convert to PyTorch tensors
mixture_spec_mic0_tensor = torch.from_numpy(mixture_spec_mic0)
mixture_spec_mic1_tensor = torch.from_numpy(mixture_spec_mic1)
#target_mask_0_tensor = target_mask_0_tensor[:,:,-1,:]
#target_mask_1_tensor = target_mask_1_tensor[:,:,-1,:]
target_mask_0_tensor = torch.from_numpy(target_mask_0)
target_mask_1_tensor = torch.from_numpy(target_mask_1)

noise_mask_0_tensor = torch.from_numpy(noise_mask_0)
noise_mask_1_tensor = torch.from_numpy(noise_mask_1)

spectrogram_tensor = torch.stack([mixture_spec_mic0_tensor, mixture_spec_mic1_tensor], dim=0)
print (spectrogram_tensor.dtype)
target_mask_tensor = torch.stack([target_mask_0_tensor,target_mask_1_tensor], dim=0) 
target_mask_tensor = target_mask_tensor[:,0,:,:]
noise_mask_tensor = torch.stack([noise_mask_0_tensor,noise_mask_1_tensor], dim=0) 
noise_mask_tensor = noise_mask_tensor[:,0,:,:]
 

torch.float64


In [6]:
print(f"Shape of spectrogram_tensor: {spectrogram_tensor.shape}")
print(f"Shape of target_mask_tensor: {target_mask_tensor.shape}")
print(f"Shape of target_mask_tensor: {noise_mask_tensor.shape}")
print (spectrogram_tensor.dtype)

Shape of spectrogram_tensor: torch.Size([2, 2, 256, 256])
Shape of target_mask_tensor: torch.Size([2, 2, 256, 256])
Shape of target_mask_tensor: torch.Size([2, 2, 256, 256])
torch.float64


In [7]:
#Compute PSD Matrices
psd_transform = torchaudio.transforms.PSD(multi_mask=True)

In [8]:
psd_matrix_target = psd_transform(spectrogram_tensor.unsqueeze(0), target_mask_tensor.unsqueeze(0))
psd_matrix_target = psd_matrix_target.squeeze(0)
psd_matrix_noise = psd_transform(spectrogram_tensor.unsqueeze(0), noise_mask_tensor.unsqueeze(0))
psd_matrix_noise = psd_matrix_noise.squeeze(0)
print (spectrogram_tensor.dtype)

torch.float64


In [9]:
print(f"Shape of psd_matrix_target: {psd_matrix_target.shape}")
print(f"Shape of psd_matrix_noise: {psd_matrix_noise.shape}")
print (spectrogram_tensor.dtype)

Shape of psd_matrix_target: torch.Size([2, 256, 2, 2])
Shape of psd_matrix_noise: torch.Size([2, 256, 2, 2])
torch.float64


In [10]:
def condition_covariance(x, gamma):
    """
    Stabilizes the covariance matrix by adding a scaled identity matrix.
    :param x: Covariance matrix (tensor)
    :param gamma: Regularization parameter
    :return: Regularized covariance matrix
    """
    identity_matrix = torch.eye(x.shape[-1], device=x.device, dtype=x.dtype)
    scale = gamma * torch.trace(x) / x.shape[-1]
    scaled_eye = identity_matrix * scale
    return (x + scaled_eye) / (1 + gamma)

def phase_correction(vector):
    """
    Phase correction to reduce distortions due to phase inconsistencies.
    Args:
        vector: Beamforming vector with shape (bins, sensors) as a complex tensor.
    Returns: Phase corrected beamforming vectors.
    """
    w = vector.clone()
    F, D = w.shape
    for f in range(1, F):
        phase_correction_factor = torch.exp(-1j * torch.angle(torch.sum(w[f, :] * w[f-1, :].conj())))
        w[f, :] *= phase_correction_factor
    return w

def get_gev_vector(target_psd_matrix, noise_psd_matrix, base_reg_param=1e-6):
    bins, num_channels, _ = noise_psd_matrix.shape
    beamforming_vector = torch.empty((bins, num_channels), dtype=torch.complex128, device=target_psd_matrix.device)
    for f in range(bins):
        regularization_success = False
        reg_param = base_reg_param
        while not regularization_success:
            try:
                # Apply regularization
                reg_identity = torch.eye(num_channels, dtype=target_psd_matrix.dtype, device=target_psd_matrix.device) * reg_param
                noise_psd_reg = noise_psd_matrix[f, :, :] + reg_identity

                # Attempt to compute the eigenvalues and eigenvectors
                eigenvals, eigenvecs = eigh(target_psd_matrix[f, :, :], noise_psd_reg)
                beamforming_vector[f, :] = eigenvecs[:, -1]
                regularization_success = True  # Exit loop if successful
            except RuntimeError as e:
                # Increase the regularization parameter if the matrix is still not positive definite
                reg_param *= 10
                if reg_param > 1e-3:  # Avoid excessively high regularization
                    print(f"Could not regularize matrix at bin {f} even with high regularization: {e}")
                    break

    return beamforming_vector

In [11]:
class GEVBeamformer:
    def __init__(self, gamma=1e-6):
        self.gamma = gamma

    def compute_psd_matrix(self, observation, mask=None, normalize=True):
        # Use the previously defined function
        return get_power_spectral_density_matrix(observation, mask, normalize)

    def condition_covariance(self, x):
        # Use the previously defined function
        return condition_covariance(x, self.gamma)  

    def get_beamforming_vector(self, target_psd_matrix, noise_psd_matrix):
        # First, condition the noise PSD matrix
        conditioned_noise_psd = self.condition_covariance(noise_psd_matrix)
        # Then, get the GEV beamforming vector
        return get_gev_vector(target_psd_matrix, conditioned_noise_psd)
    
    def gev_wrapper_on_masks(self, mix, target_psd, noise_psd, normalization=False):
        org_dtype = mix.dtype
        mix = mix.to(torch.cdouble)  # Convert to complex double precision
        mix = mix.transpose(0, 1)  # Transpose the matrix

        # Condition the noise covariance matrix
        noise_psd = condition_covariance(noise_psd, 1e-6)
    
        # Calculate the trace, reshaping appropriately
        noise_psd = torch.diagonal(noise_psd, dim1=-2, dim2=-1).sum(-1, keepdim=True).unsqueeze(-1)
        
        #print shapes of masks for debugging
        print (f'The shape of noise_psd is: {noise_psd.shape}')
        print (f'The shape of target_psd is: {target_psd.shape}')
        

        # Get the GEV vector
        W_gev = get_gev_vector(target_psd, noise_psd)

        # Apply phase correction
        W_gev = phase_correction(W_gev)

        # Apply normalization if specified
        if normalization:
            W_gev = blind_analytic_normalization(W_gev, noise_psd)

        # Apply the beamforming vector
        output = apply_beamforming_vector(W_gev, mix)

        # Convert the output back to the original data type and transpose back
        output = output.to(org_dtype).transpose(0, 1)

        return output
   

In [13]:
print (spectrogram_tensor.dtype)
spectrogram_tensor = spectrogram_tensor.to(torch.cdouble)  # Convert to complex double precision
spectrogram_tensor = spectrogram_tensor.transpose(0, 1)  # Transpose the matrix
print (spectrogram_tensor.dtype)
gev = GEVBeamformer()
gev.gev_wrapper_on_masks(spectrogram_tensor, psd_matrix_target,  psd_matrix_noise)

torch.complex128
torch.complex128


RuntimeError: trace: expected a matrix, but got tensor with dim 4