MVDR beamforming in which a neural network (pretrained or otherwise; different architectures can be substituted) predicts a mask over the time–frequency bins, and that mask refines the estimate of the spatial covariance for each time chunk.

The beamformer is made adaptive by a forgetting factor $\alpha$. Each chunk's spatial covariance is a weighted combination of the running estimate from the previous chunks and the covariance computed from the current chunk alone: $R_t = \alpha R_{t-1} + (1-\alpha) R_{\text{inst},t}$. Beamformer weights are computed separately for every chunk and applied only to the STFT values of that time period.

In [None]:
import numpy as np
import soundfile as sf
from scipy.signal import stft, istft, windows
import torch
import torch.nn as nn

audio_path = '/kaggle/input/audios/audio_dataset/audio_dataset/samsung_non_overlapping/02-25.12-20-54-168__WL_BH_d1m_Left_TC5.hdf/02-25.12-20-54-168__WL_BH_d1m_Left_TC5.wav'
audio, fs = sf.read(audio_path)  # Expect shape (n_samples, 4)
# Validate with explicit exceptions rather than `assert`: assert statements
# are stripped when Python runs with -O, which would silently skip the checks.
if fs != 16000:
    raise ValueError(f"Expected sampling rate 16000, but got {fs}")
# ndim is checked first so a mono (1-D) file never reaches shape[1].
if audio.ndim != 2 or audio.shape[1] != 4:
    raise ValueError("Audio must be 4-channel")
audio = np.asarray(audio, dtype=np.float32)
print(f"Loaded audio with shape {audio.shape}, fs = {fs}")

In [None]:
# Array geometry: one (x, y, z) row per microphone.
mic_positions = np.asarray(
    [
        [-0.05, 0.00, 0.00],   # Mic 1 (right)
        [0.05, 0.00, 0.00],    # Mic 2 (left)
        [-0.08, 0.045, 0.04],  # Mic 3 (upper-right-back)
        [0.08, 0.045, 0.04],   # Mic 4 (upper-left-back)
    ],
    dtype=np.float32,
)

# Target (near-field) source location, in the same coordinate frame as the mics.
source_pos = np.asarray([0.0, -0.06, 0.0], dtype=np.float32)

# Direction of a far-field noise source: azimuth in the horizontal x-z plane,
# with 0 degrees pointing to the front (-z). Change doa_noise_deg per
# segment/frame if the noise direction moves.
doa_noise_deg = 90.0
az = np.deg2rad(doa_noise_deg)
doa_noise_vec = np.array([np.sin(az), 0.0, -np.cos(az)], dtype=np.float32)
doa_noise_vec = doa_noise_vec / np.linalg.norm(doa_noise_vec)  # unit length

print("Mic positions:\n", mic_positions)
print("Source position:", source_pos)
print(f"Noise DoA vector (az={doa_noise_deg}°):", doa_noise_vec)

In [None]:
n_fft = 512
hop_length = n_fft // 2          # 50% overlap
window = windows.hann(n_fft, sym=False)

# Shared STFT settings; boundary=None / padded=False means only whole frames
# that fit inside the signal are analysed.
stft_kwargs = dict(
    fs=fs,
    window=window,
    nperseg=n_fft,
    noverlap=n_fft - hop_length,
    boundary=None,
    padded=False,
)

# One STFT per microphone; scipy returns (freqs, times, Zxx) with Zxx of shape (F, T).
channel_results = [stft(audio[:, ch], **stft_kwargs) for ch in range(4)]
f_bins, t_frames, _ = channel_results[-1]

# Put the microphone index last: (F_bins, T_frames, 4)
stfts = np.stack([spec for _, _, spec in channel_results], axis=2)
F_bins, T_frames, M = stfts.shape
assert M == 4
freqs = f_bins  # bin centre frequencies, length F_bins
print(f"STFT computed: freq bins={F_bins}, time frames={T_frames}, channels={M}")

In [None]:
c = 343.0  # speed of sound; units presumably m/s to match the metre-scale mic coordinates (confirm); passed as speed_of_sound to the steering-vector functions

def steering_vector_nearfield(mic_positions, source_pos, freqs, speed_of_sound=343.0, include_amplitude=False):
    """
    Compute near-field (spherical-wave) steering vectors for a point source.

    For microphone m at distance d_m from the source, the entry at frequency f
    is exp(-j * 2*pi * f * d_m / speed_of_sound). When include_amplitude is
    True it is additionally divided by d_m (1/r spreading). Distances are
    clamped to at least 1e-6, so a microphone placed exactly at the source
    does not cause a division by zero.

    Parameters
    ----------
    mic_positions : array_like, shape (M, 3)
        Microphone coordinates, in the same length unit as speed_of_sound.
    source_pos : array_like, shape (3,)
        Source coordinates.
    freqs : array_like or scalar
        Frequencies in Hz (e.g. the STFT bin frequencies). A scalar is
        treated as a single frequency.
    speed_of_sound : float
        Propagation speed.
    include_amplitude : bool
        Whether to include the 1/d amplitude term.

    Returns
    -------
    np.ndarray, complex64, shape (F_bins, M)
    """
    mic_positions = np.asarray(mic_positions)
    source_pos = np.asarray(source_pos)
    # Flatten to 1-D so scalars and lists work; compute phases in float64.
    freqs = np.asarray(freqs, dtype=np.float64).reshape(-1)

    dists = np.linalg.norm(mic_positions - source_pos[None, :], axis=1)  # (M,)
    dists = np.maximum(dists, 1e-6)

    # The outer product gives the (F, M) grid of f * d_m in one vectorised step.
    a = np.exp(-2j * np.pi * np.outer(freqs, dists) / speed_of_sound)
    if include_amplitude:
        a = a / dists[None, :]
    return a.astype(np.complex64)

def steering_vector_farfield(mic_positions, doa_vec, freqs, speed_of_sound=343.0):
    """
    Compute far-field (plane-wave) steering vectors.

    With projection p_m = mic_positions[m] . doa_vec, the entry at frequency f
    is exp(-j * 2*pi * f * p_m / speed_of_sound). Microphones with a larger
    projection therefore lag in phase. That matches doa_vec being the
    direction in which the wave propagates.

    NOTE(review): if doa_vec is meant to point from the array *toward* the
    source, this is the complex conjugate of the convention used by
    steering_vector_nearfield (where the phase grows with distance to the
    source). Confirm the intended convention before using these vectors.

    Parameters
    ----------
    mic_positions : array_like, shape (M, 3)
        Microphone coordinates, in the same length unit as speed_of_sound.
    doa_vec : array_like, shape (3,)
        Direction vector. Expected to be unit length; it is not normalised here.
    freqs : array_like or scalar
        Frequencies in Hz. A scalar is treated as a single frequency.
    speed_of_sound : float
        Propagation speed.

    Returns
    -------
    np.ndarray, complex64, shape (F_bins, M)
    """
    proj = np.asarray(mic_positions) @ np.asarray(doa_vec)  # (M,)
    freqs = np.asarray(freqs, dtype=np.float64).reshape(-1)
    # Vectorised over all frequencies at once: (F, M)
    a = np.exp(-2j * np.pi * np.outer(freqs, proj) / speed_of_sound)
    return a.astype(np.complex64)

# Steering vectors on the STFT frequency grid, both shaped (F_bins, M):
#   a_n - far-field plane wave for the noise DoA (not referenced by the MVDR cell shown here)
#   a_s - near-field, phase-only model of the target source (used by MVDR)
a_n = steering_vector_farfield(mic_positions, doa_noise_vec, freqs, speed_of_sound=c)
a_s = steering_vector_nearfield(
    mic_positions,
    source_pos,
    freqs,
    speed_of_sound=c,
    include_amplitude=False,
)

In [None]:
# Placeholder CNN + BLSTM mask estimator. It has not been tested yet and is
# intended to run with pretrained checkpoint weights; until such weights are
# loaded, its output should not be relied on as a meaningful mask.

class MaskNet(nn.Module):
    """Predict a (time, frequency) mask from magnitude and IPD feature maps.

    Two 3x3 Conv2d/BatchNorm/ReLU layers extract local time-frequency
    features. A single-layer bidirectional LSTM then runs over time, and a
    linear layer plus sigmoid maps each frame to one value in (0, 1) per
    frequency bin.

    Args:
        n_freq_bins: Number of frequency bins F. Inputs must have exactly this
            many bins, because the LSTM input size is 16 * n_freq_bins.
        n_ipd_pairs: Number of microphone pairs in the IPD input. Each pair
            contributes two input channels.
    """

    def __init__(self, n_freq_bins, n_ipd_pairs):
        super().__init__()
        # Plain ints used only for input validation; they are not parameters
        # or buffers, so checkpoint state_dict keys are unaffected.
        self.n_freq_bins = n_freq_bins
        self.in_channels = 1 + 2 * n_ipd_pairs
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, out_channels=16, kernel_size=(3, 3), padding=(1, 1)),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 16, kernel_size=(3, 3), padding=(1, 1)),
            nn.BatchNorm2d(16),
            nn.ReLU(),
        )
        self.blstm = nn.LSTM(input_size=16 * n_freq_bins, hidden_size=128, num_layers=1,
                             batch_first=True, bidirectional=True)
        self.fc = nn.Linear(128 * 2, n_freq_bins)
        self.sigmoid = nn.Sigmoid()

    def forward(self, mag, ipd):
        """Return the mask for a batch of feature maps.

        Args:
            mag: (batch, 1, T, F) magnitude features.
            ipd: (batch, 2 * n_ipd_pairs, T, F) IPD features.

        Returns:
            (batch, T, F) tensor of values in (0, 1).

        Raises:
            ValueError: if the channel count or the number of frequency bins
                does not match the sizes given to the constructor.
        """
        x = torch.cat([mag, ipd], dim=1)  # (batch, channels, T, F)
        if x.size(1) != self.in_channels:
            raise ValueError(
                f"MaskNet expects {self.in_channels} input channels, got {x.size(1)}"
            )
        if x.size(-1) != self.n_freq_bins:
            raise ValueError(
                f"MaskNet expects {self.n_freq_bins} frequency bins, got {x.size(-1)}"
            )
        x = self.conv(x)  # (batch, 16, T, F)
        batch, channels, n_frames, n_freq = x.shape
        # Fold channels and frequency into one feature vector per frame.
        x = x.permute(0, 2, 1, 3).reshape(batch, n_frames, channels * n_freq)  # (batch, T, 16*F)
        y, _ = self.blstm(x)  # (batch, T, 2*hidden)
        return self.sigmoid(self.fc(y))  # (batch, T, F)

# Build the mask estimator for this recording's STFT size. IPD features are
# taken between each of the other M - 1 microphones and reference mic 0.
n_freq_bins = F_bins
n_ipd_pairs = M - 1
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
mask_net = MaskNet(n_freq_bins=n_freq_bins, n_ipd_pairs=n_ipd_pairs).to(device)
# TODO: load pretrained weights, e.g.:
# mask_net.load_state_dict(torch.load('masknet_checkpoint.pth'))
mask_net.eval()  # inference mode: BatchNorm uses its running statistics

In [None]:
# Reorder the (F, T, M) STFT to (M, T, F) and add a batch axis: (1, M, T, F)
stft_tensor = torch.from_numpy(stfts.transpose(2, 1, 0)).unsqueeze(0)

# Magnitude of the reference microphone (mic 0): (1, 1, T, F)
mag = stft_tensor[:, :1].abs()

# IPD features: for each mic i >= 1, the unit-magnitude cross-spectrum
# X_i * conj(X_0), split into real and imaginary channels (interleaved per mic).
eps = 1e-9
ref_spec = stft_tensor[:, 0].type(torch.complex64)  # (1, T, F), reused for every mic
ipd_list = []
for mic in range(1, M):
    cross = stft_tensor[:, mic].type(torch.complex64) * ref_spec.conj()
    phasor = cross / (cross.abs() + eps)
    ipd_list.extend((phasor.real.unsqueeze(1), phasor.imag.unsqueeze(1)))
ipd = torch.cat(ipd_list, dim=1)  # (1, 2*(M-1), T, F)

# Run the network on a batch of one recording.
with torch.no_grad():
    mask_pred = mask_net(mag.to(device), ipd.to(device))  # (1, T, F)
mask_pred = mask_pred.squeeze(0).cpu().numpy()  # (T_frames, F_bins)
print("Predicted noise mask shape:", mask_pred.shape)

In [None]:
alpha = 0.95      # forgetting factor, adjust for adaptivity vs stability
diag_loading = 1e-6
eps = 1e-9


def _mvdr_weights(R, steer, loading, tol):
    """Return MVDR weights for a stack of per-frequency covariance matrices.

    For every frequency f computes w_f = R_f^{-1} a_f / (a_f^H R_f^{-1} a_f),
    where R_f is first diagonally loaded with ``loading * I``. The loading is
    applied to a copy, so ``R`` itself is not modified. Frequencies whose
    denominator magnitude is below ``tol`` get all-zero weights.

    Args:
        R: (F, M, M) complex covariance matrices.
        steer: (F, M) complex steering vectors.
        loading: Diagonal loading added before solving.
        tol: Threshold on |a^H R^{-1} a| below which the weights are zeroed.

    Returns:
        (F, M) complex64 weight vectors.
    """
    n_mics = R.shape[-1]
    R_loaded = R + loading * np.eye(n_mics, dtype=R.dtype)
    rhs = steer[:, :, None]  # (F, M, 1) column vectors
    try:
        # Batched solve is cheaper and numerically better than forming R^{-1}.
        rinv_a = np.linalg.solve(R_loaded, rhs)[:, :, 0]
    except np.linalg.LinAlgError:
        # Some matrix in the stack is exactly singular: use the pseudo-inverse.
        rinv_a = (np.linalg.pinv(R_loaded) @ rhs)[:, :, 0]
    denom = np.sum(np.conj(steer) * rinv_a, axis=1)  # a^H R^{-1} a, shape (F,)
    valid = np.abs(denom) >= tol
    weights = np.zeros_like(rinv_a, dtype=np.complex64)
    weights[valid] = rinv_a[valid] / denom[valid][:, None]
    return weights


# Recursive noise covariance estimates R_n_est[f], initialised to the identity.
# NOTE(review): the identity prior has unit power per mic and decays only as
# alpha**t. If the STFT values are small (scipy.signal.stft scales by
# 1/sum(window) by default), it can dominate the estimate for many frames.
# Consider initialising from data.
R_n_est = np.tile(np.eye(M, dtype=np.complex64), (F_bins, 1, 1))  # (F_bins, M, M)

# Allocate output STFT array
Y = np.zeros((F_bins, T_frames), dtype=np.complex64)

for t in range(T_frames):
    X_t = stfts[:, t, :]      # (F_bins, M)
    mask_t = mask_pred[t, :]  # (F_bins,)
    # Mask-weighted instantaneous covariance x x^H for all frequencies: (F_bins, M, M)
    R_inst = mask_t[:, None, None] * (X_t[:, :, None] * np.conj(X_t)[:, None, :])
    # Exponentially-forgetting recursive update. Diagonal loading is NOT added
    # here: doing so accumulates it in the state (steady state δ/(1-alpha)).
    R_n_est = alpha * R_n_est + (1 - alpha) * R_inst
    # MVDR weights per frequency (loading applied to a copy inside the helper)
    W_t = _mvdr_weights(R_n_est, a_s, diag_loading, eps)
    # Apply beamforming for frame t: Y[f,t] = w(f,t)^H @ X(f,t)
    Y[:, t] = np.sum(np.conj(W_t) * X_t, axis=1)

In [None]:
_, output = istft(Y, fs=fs, window=window,
                  nperseg=n_fft, noverlap=n_fft-hop_length,
                  input_onesided=True, boundary=None)

# The analysis STFT used boundary=None / padded=False, so trailing samples
# that do not fill a whole frame were dropped. The reconstruction can
# therefore be shorter than the input. Trim or zero-pad to the input length
# so the output stays sample-aligned with the original recording.
n_samples = audio.shape[0]
output = output[:n_samples]
if output.shape[0] < n_samples:
    output = np.pad(output, (0, n_samples - output.shape[0]))

# Normalize to avoid clipping (only when the peak exceeds full scale)
max_val = np.max(np.abs(output)) + 1e-9
output_norm = output / max_val if max_val > 1.0 else output

output_path = 'adaptive_mvdr_output.wav'
sf.write(output_path, output_norm, fs)
print(f"Adaptive MVDR output written to {output_path}")