In [1]:
import numpy as np
import matplotlib.pyplot as plt
import librosa
from scipy.linalg import eigh, LinAlgError

In [2]:
import torchaudio.transforms
import torch

In [3]:
### Loading data
# Per-microphone noise / target masks saved by the model.
noise_mask_0, target_mask_0 = (
    np.load('../model_results/model_output/noise_mask_mic0.npy'),
    np.load('../model_results/model_output/target_mask_mic0.npy'),
)
noise_mask_1, target_mask_1 = (
    np.load('../model_results/model_output/noise_mask_mic1.npy'),
    np.load('../model_results/model_output/target_mask_mic1.npy'),
)

# Disabled: predicted spectrogram real/imaginary parts.
#pred_spec_1_imag = np.load('../model_results/model_output/pred_spec_1_imag.npy')
#pred_spec_1_real = np.load('../model_results/model_output/pred_spec_1_real.npy')
#pred_spec_2_imag = np.load('../model_results/model_output/pred_spec_2_imag.npy')
#pred_spec_2_real = np.load('../model_results/model_output/pred_spec_2_real.npy')

# Mixture spectrograms, keeping only the first 256 entries along axis 1
# (presumably the frequency axis -- TODO confirm).
mixture_spec_mic0 = np.load('../model_results/model_output/mixture_spec_mic0.npy')[:, :256, :]
mixture_spec_mic1 = np.load('../model_results/model_output/mixture_spec_mic1.npy')[:, :256, :]


FileNotFoundError: [Errno 2] No such file or directory: '../model_results/model_output/noise_mask_mic0.npy'

In [None]:
# Report the shape of every loaded array, one line per array.
for shape_report in (
    f"Shape of target_mask_0: {target_mask_0.shape}",
    f"Shape of noise_mask_0: {noise_mask_0.shape}",
    f"Shape of target_mask_1: {target_mask_1.shape}",
    f"Shape of noise_mask_1: {noise_mask_1.shape}",
    f"Shape of mixture_spec_mic0: {mixture_spec_mic0.shape}",
    f"Shape of mixture_spec_mic1: {mixture_spec_mic1.shape}",
):
    print(shape_report)

# Optional basic statistics about the masks (disabled):
#print(f"Target Mask - Min: {np.min(target_mask)}, Max: {np.max(target_mask)}, Mean: {np.mean(target_mask)}, Std: {np.std(target_mask)}")
#print(f"Noise Mask - Min: {np.min(noise_mask)}, Max: {np.max(noise_mask)}, Mean: {np.mean(noise_mask)}, Std: {np.std(noise_mask)}")


In [None]:
# Prepare data for PSD calculation: wrap each NumPy array as a PyTorch tensor.
mixture_spec_mic0_tensor, mixture_spec_mic1_tensor = (
    torch.from_numpy(mixture_spec_mic0),
    torch.from_numpy(mixture_spec_mic1),
)
# Alternative mask slicing that was tried and disabled:
#target_mask_0_tensor = target_mask_0_tensor[:,:,-1,:]
#target_mask_1_tensor = target_mask_1_tensor[:,:,-1,:]
target_mask_0_tensor, target_mask_1_tensor = (
    torch.from_numpy(target_mask_0),
    torch.from_numpy(target_mask_1),
)
noise_mask_0_tensor, noise_mask_1_tensor = (
    torch.from_numpy(noise_mask_0),
    torch.from_numpy(noise_mask_1),
)

# Stack the two microphones along a new leading axis.
spectrogram_tensor = torch.stack(
    (mixture_spec_mic0_tensor, mixture_spec_mic1_tensor), dim=0)
print(spectrogram_tensor.dtype)

# Stack the masks the same way, then keep only index 0 of axis 1.
target_mask_tensor = torch.stack(
    (target_mask_0_tensor, target_mask_1_tensor), dim=0)[:, 0, :, :]
noise_mask_tensor = torch.stack(
    (noise_mask_0_tensor, noise_mask_1_tensor), dim=0)[:, 0, :, :]
 

In [None]:
# Sanity-check the stacked tensors before computing the PSD matrices.
print(f"Shape of spectrogram_tensor: {spectrogram_tensor.shape}")
print(f"Shape of target_mask_tensor: {target_mask_tensor.shape}")
# This line reports the noise mask, so it is labelled as such.
print(f"Shape of noise_mask_tensor: {noise_mask_tensor.shape}")
print(spectrogram_tensor.dtype)

In [None]:
# Compute PSD matrices.
# Build the torchaudio PSD transform once; later cells call it with a
# spectrogram tensor and a mask tensor.
# multi_mask=True presumably means one mask per channel is expected --
# TODO confirm against the torchaudio PSD documentation.
psd_transform = torchaudio.transforms.PSD(multi_mask=True)

In [None]:
# Both PSD estimates use the same spectrogram, so add its leading axis once.
batched_spectrogram = spectrogram_tensor.unsqueeze(0)

# Each call gets inputs with an extra leading axis of size 1; that axis is
# squeezed out of the result again.
psd_matrix_target = psd_transform(
    batched_spectrogram, target_mask_tensor.unsqueeze(0)).squeeze(0)
psd_matrix_noise = psd_transform(
    batched_spectrogram, noise_mask_tensor.unsqueeze(0)).squeeze(0)
print(spectrogram_tensor.dtype)

In [None]:
# Back to NumPy.
def _to_numpy(value):
    """Return `value` as a NumPy array.

    Tensors are detached and moved to the CPU first. Anything else goes
    through `np.asarray`, so re-running this cell after `psd_matrix_noise`
    and `psd_matrix_target` have already been rebound to arrays no longer
    fails on `.detach()`.
    """
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy()
    return np.asarray(value)


spectrogram_array = _to_numpy(spectrogram_tensor)
target_mask_array = _to_numpy(target_mask_tensor)
noise_mask_array = _to_numpy(noise_mask_tensor)
# These two names are rebound in place from tensor to ndarray.
psd_matrix_noise = _to_numpy(psd_matrix_noise)
psd_matrix_target = _to_numpy(psd_matrix_target)

In [None]:
# Report the PSD matrix shapes and the spectrogram dtype, one line each.
for psd_report in (
    f"Shape of psd_matrix_target: {psd_matrix_target.shape}",
    f"Shape of psd_matrix_noise: {psd_matrix_noise.shape}",
    spectrogram_array.dtype,
):
    print(psd_report)

In [None]:
class GEVBeamformer:
    """Mask-based GEV (generalized eigenvalue) beamformer.

    For every frequency bin the beamforming vector is the eigenvector that
    belongs to the largest generalized eigenvalue of the
    (target PSD, noise PSD) matrix pair.

    Args:
        gamma: Diagonal-loading factor used whenever the noise PSD matrices
            are conditioned without an explicit factor.
    """

    def __init__(self, gamma=1e-6):
        self.gamma = gamma

    @staticmethod
    def _per_bin_matrices(psd_matrix):
        """Return `psd_matrix` as a (bins, sensors, sensors) array.

        A 4-D input keeps only index 0 of its leading axis; a 3-D input is
        returned as is. Any other rank raises ValueError.
        """
        psd_matrix = np.asarray(psd_matrix)
        if psd_matrix.ndim == 4:
            return psd_matrix[0]
        if psd_matrix.ndim == 3:
            return psd_matrix
        raise ValueError(
            'PSD matrix must have 3 or 4 dimensions, got shape {}'.format(
                psd_matrix.shape))

    def condition_covariance(self, x, gamma=None):
        """Diagonally load covariance matrices to improve their conditioning.

        see https://stt.msu.edu/users/mauryaas/Ashwini_JPEN.pdf (2.3)

        Args:
            x: Array with shape (..., sensors, sensors). Every matrix along
                the leading axes is loaded with its own trace.
            gamma: Loading factor; defaults to ``self.gamma`` when omitted.

        Returns:
            Array with the same shape as ``x``.
        """
        if gamma is None:
            gamma = self.gamma
        x = np.asarray(x)
        sensors = x.shape[-1]
        # Take the trace over the last two axes so that stacked matrices are
        # handled one by one (np.trace defaults to axes 0 and 1).
        per_matrix_trace = np.asarray(np.trace(x, axis1=-2, axis2=-1))
        scale = gamma * per_matrix_trace[..., None, None] / sensors
        scaled_eye = np.eye(sensors) * scale
        return (x + scaled_eye) / (1 + gamma)

    def phase_correction(self, vector):
        """Phase correction to reduce distortions due to phase inconsistencies.

        Each bin is rotated so that its inner product with the previous
        (already corrected) bin has zero phase.

        Args:
            vector: Beamforming vectors with shape (bins, sensors).

        Returns:
            Phase corrected beamforming vectors (a copy). Lengths remain.
        """
        w = vector.copy()
        if w.dtype.kind != 'c':
            # The in-place rotation below needs a complex buffer.
            w = w.astype(np.complex128)
        bins, _ = w.shape
        # Sequential on purpose: bin f is aligned to the corrected bin f-1.
        for f in range(1, bins):
            w[f, :] *= np.exp(-1j*np.angle(
                np.sum(w[f, :] * w[f-1, :].conj(), axis=-1, keepdims=True)))
        return w

    def get_gev_vector(self, target_psd_matrix, noise_psd_matrix):
        """Returns the GEV beamforming vector.

        :param target_psd_matrix: Target PSD matrix
            with shape (bins, sensors, sensors), or a 4-D array whose
            index 0 along the leading axis has that shape
        :param noise_psd_matrix: Noise PSD matrix, same layout as the target
        :return: Set of beamforming vectors with shape (bins, sensors)
        """
        target_psd_matrix = self._per_bin_matrices(target_psd_matrix)
        noise_psd_matrix = self._per_bin_matrices(noise_psd_matrix)
        bins, sensors, _ = target_psd_matrix.shape
        beamforming_vector = np.empty((bins, sensors), dtype=np.complex128)
        for f in range(bins):
            try:
                # eigh returns eigenvalues in ascending order, so the last
                # column belongs to the largest generalized eigenvalue.
                eigenvals, eigenvecs = eigh(target_psd_matrix[f, :, :],
                                            noise_psd_matrix[f, :, :])
                beamforming_vector[f, :] = eigenvecs[:, -1]
            except np.linalg.LinAlgError:
                print('LinAlg error for frequency {}'.format(f))
                # Fallback: uniform weights scaled by the noise trace.
                beamforming_vector[f, :] = (
                    np.ones((sensors,)) / np.trace(noise_psd_matrix[f]) * sensors)
        return beamforming_vector

    def get_beamforming_vector(self, target_psd_matrix, noise_psd_matrix):
        """Condition the noise PSD with ``self.gamma`` and return GEV vectors.

        Accepts the same layouts as ``get_gev_vector`` and returns an array
        with shape (bins, sensors).
        """
        # First, condition the noise PSD matrix
        conditioned_noise_psd = self.condition_covariance(
            noise_psd_matrix, self.gamma)
        # Then, get the GEV beamforming vector
        return self.get_gev_vector(target_psd_matrix, conditioned_noise_psd)

    def apply_beamforming_vector(self, vector, mix):
        """Apply conjugated beamforming weights to a multi-channel mixture.

        Args:
            vector: Weights with shape (..., sensors).
            mix: Mixture with shape (..., sensors, frames).

        Returns:
            Array with shape (..., frames).
        """
        return np.einsum('...a,...at->...t', vector.conj(), mix)

    def blind_analytic_normalization(self, vector, noise_psd_matrix):
        """Scale each bin's weights by a blind analytic normalization gain.

        For weights ``w`` and noise PSD ``P`` the gain of a bin is
        ``sqrt(|w^H P P w|) / |w^H P w|``.

        Args:
            vector: Beamforming vectors with shape (bins, sensors).
            noise_psd_matrix: Noise PSD, same layouts as ``get_gev_vector``.

        Returns:
            Rescaled beamforming vectors with shape (bins, sensors).
        """
        noise = self._per_bin_matrices(noise_psd_matrix)
        numerator = np.sqrt(np.abs(np.einsum(
            'fa,fab,fbc,fc->f', vector.conj(), noise, noise, vector)))
        denominator = np.abs(np.einsum(
            'fa,fab,fb->f', vector.conj(), noise, vector))
        # Keep the division finite if a bin has zero noise power.
        denominator = np.maximum(denominator, np.finfo(np.float64).tiny)
        return vector * (numerator / denominator)[:, None]

    def gev_wrapper_on_masks(self, mix, target_psd_matrix, noise_psd_matrix,
                         normalization=False):
        """Beamform ``mix`` with GEV weights derived from the PSD matrices.

        Args:
            mix: Mixture array. It is transposed (all axes reversed) before
                the weights are applied, so it must be laid out as
                (frames, sensors, bins).
            target_psd_matrix: Target PSD, see ``get_gev_vector``.
            noise_psd_matrix: Noise PSD, see ``get_gev_vector``. It is
                conditioned with ``self.gamma`` and trace-normalized on a
                copy; the caller's array is left untouched.
            normalization: If True, apply blind analytic normalization to
                the weights.

        Returns:
            Beamformed output with shape (frames, bins), cast back to the
            dtype ``mix`` had on entry.
        """
        org_dtype = mix.dtype
        mix = mix.astype(np.complex128)
        mix = mix.T
        noise_psd_matrix = self.condition_covariance(
            noise_psd_matrix, self.gamma)
        noise_psd_matrix /= np.trace(
            noise_psd_matrix, axis1=-2, axis2=-1)[..., None, None]
        W_gev = self.get_gev_vector(target_psd_matrix, noise_psd_matrix)
        print (f'Shape of GEV vector: {W_gev.shape}')
        W_gev = self.phase_correction(W_gev)

        if normalization:
            W_gev = self.blind_analytic_normalization(W_gev, noise_psd_matrix)
        output = self.apply_beamforming_vector(W_gev, mix)
        # NOTE(review): if `mix` was real-valued on entry this cast drops
        # the imaginary part of the beamformed signal -- confirm intended.
        output = output.astype(org_dtype)

        return output.T
   

In [None]:
# Apply beamforming: run the GEV beamformer on the stacked spectrogram using
# the target / noise PSD matrices computed above, then save the result.
print (spectrogram_array.dtype)
gev_beamformer = GEVBeamformer()
gev = gev_beamformer.gev_wrapper_on_masks(spectrogram_array, 
                         psd_matrix_target, 
                         psd_matrix_noise)
# NOTE(review): np.save is documented to append '.npy' when the name does not
# already end with it, so this presumably lands in 'gev.np.npy' -- confirm
# whether 'gev.npy' was intended.
np.save('gev.np',gev)

In [None]:
# Reference clip, loaded at a fixed 16 kHz (sr=TARGET_SR asks librosa for
# that rate rather than the file's native one).
# NOTE(review): absolute, machine-specific path -- consider a configurable
# data directory.
TARGET_SR = 16000
reference_wav_path = "C:/Users/yosra/Documents/AV-speech-separation/data/simulated_RIR/VoxCeleb2/raw_audio/id04030/JbcD0P6KGe0/00039_mic0_voice0.wav"
y, sr = librosa.load(reference_wav_path, sr=TARGET_SR)

# Clip length in seconds, then expressed as a sample count at TARGET_SR.
duration = librosa.get_duration(y=y, sr=sr)
context_samples = duration * TARGET_SR

In [None]:
def istft_reconstruction(mag, phase, hop_length=160, win_length=400, length=65535):
    """Rebuild a time-domain signal from magnitude and phase arrays.

    Args:
        mag: Magnitude array; cast to complex128 before use.
        phase: Phase array, combined as ``mag * exp(1j * phase)``.
        hop_length: Hop length forwarded to ``librosa.istft``.
        win_length: Window length forwarded to ``librosa.istft``.
        length: Output length forwarded to ``librosa.istft``.

    Returns:
        The inverse-STFT waveform with every sample clipped to [-1, 1].
    """
    magnitude = mag.astype(np.complex128)
    rotation = np.exp(1j*phase)
    complex_spec = magnitude * rotation
    waveform = librosa.istft(
        complex_spec,
        hop_length=hop_length,
        win_length=win_length,
        length=length,
    )
    return np.clip(waveform, -1., 1.)

In [None]:
from signal_processing import audiowrite, stft, istft
from IPython.display import Audio

In [None]:
# Inspect the beamformer output, then rebuild a time-domain waveform from it.
print(f'The GEV vector type is: {gev.dtype}')
print(f'The GEV vector shape is: {gev.shape}')
# `phase` is index 1 along the first axis of the stacked spectrogram, i.e.
# the mixture_spec_mic1 entry.
# NOTE(review): the raw spectrogram values are passed on as the phase; if the
# spectrogram is complex, np.angle(...) was presumably intended -- confirm.
phase = spectrogram_array[1,:,:,:]
print(f'The phase shape is: {phase.shape}')
wav =istft_reconstruction(gev, phase, hop_length=160, win_length=400, length=65535)
# Disabled alternative using the project-local istft helper:
#time_signal = istft(gev, size=512, shift=128, window_length=400)
# NOTE(review): `filename` is only referenced by the disabled audiowrite call
# below, so no WAV file is written by this cell.
filename = 'ay_rabi_m3ak.wav'
# Write the audio data to a WAV file (disabled)
#audiowrite(audio_time_domain, filename, 16000, True, True)

In [None]:
# Play the reconstructed waveform `wav` inline with an IPython Audio widget.
display(Audio(wav, rate=16000))  # Replace 16000 with your sample rate

In [None]:
def plot_waveform(waveform=None, sampling_rate=1000, title="Waveform Plot", x_label="Time (s)", y_label="Amplitude", theme="light"):
    """
    Plots a waveform with customization options.

    Parameters:
    - waveform: ndarray, optional
        The waveform data to plot. If None, a default 5 Hz sine wave
        spanning one second is generated.
    - sampling_rate: int, optional
        The sampling rate of the waveform in Hz.
    - title: str, optional
        The title of the plot.
    - x_label: str, optional
        The label for the x-axis.
    - y_label: str, optional
        The label for the y-axis.
    - theme: str, optional
        The theme of the plot, "light" or "dark". Any value other than
        "dark" selects the light theme.

    Returns:
    - None

    Note: the chosen style is applied with plt.style.use, so it stays active
    for figures created afterwards.
    """

    # Create a default sine waveform if none is provided
    if waveform is None:
        t = np.linspace(0, 1, sampling_rate)
        waveform = np.sin(2 * np.pi * 5 * t)  # 5 Hz sine wave
    else:
        t = np.arange(len(waveform)) / sampling_rate

    # Set theme
    if theme == "dark":
        style_name = "dark_background"
    else:
        # The seaborn style sheets were renamed to "seaborn-v0_8-*" in newer
        # Matplotlib releases; use whichever name this installation provides
        # and fall back to the default style if neither is available.
        light_styles = ("seaborn-v0_8-whitegrid", "seaborn-whitegrid")
        style_name = next(
            (name for name in light_styles if name in plt.style.available),
            "default",
        )
    plt.style.use(style_name)

    # Plot the waveform
    plt.figure(figsize=(10, 5))
    plt.plot(t, waveform, linewidth=2)
    plt.title(title)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.grid(True)
    plt.tight_layout()
    plt.show()

In [None]:
# Plot the reconstructed waveform `wav` on a 16 kHz time axis, dark theme.
plot_waveform(waveform=wav, sampling_rate=16000, title="Waveform Plot", x_label="Time (s)", y_label="Amplitude", theme="dark")