Estimate the DoA for each time frame (combining information across all frequencies), then suppress the frame if its estimated DoA falls outside a given angular threshold of the target direction

In [None]:
import numpy as np
import soundfile as sf
from scipy.signal import stft, istft
from scipy.signal.windows import hann

def gcc_phat(sig, refsig, fs, max_tau=None, interp=16):
    """Estimate the delay of ``sig`` relative to ``refsig`` with GCC-PHAT.

    The cross-spectrum of the two signals is normalised to unit magnitude
    (phase transform), transformed back to the lag domain on a grid that is
    ``interp`` times finer than the sample period, and the lag with the
    largest absolute correlation is returned.

    Args:
        sig: 1-D signal whose delay is measured.
        refsig: 1-D reference signal.
        fs: Sampling rate (assumed to be in Hz, which makes the result
            seconds).
        max_tau: Optional bound on the absolute delay to search. Any falsy
            value (``None`` or ``0``) leaves the search unbounded.
        interp: Up-sampling factor of the lag grid.

    Returns:
        The estimated delay ``shift / (interp * fs)``; positive when ``sig``
        lags ``refsig``.
    """
    fft_len = sig.shape[0] + refsig.shape[0]
    upsampled_len = interp * fft_len

    cross_spectrum = np.fft.rfft(sig, n=fft_len) * np.conj(np.fft.rfft(refsig, n=fft_len))
    # PHAT weighting: keep only the phase; the epsilon avoids 0/0 on empty bins.
    whitened = cross_spectrum / (np.abs(cross_spectrum) + 1e-15)
    correlation = np.fft.irfft(whitened, n=upsampled_len)

    half_span = int(upsampled_len / 2)
    if max_tau:
        half_span = np.minimum(int(interp * fs * max_tau), half_span)

    # Re-order so that negative lags come first and zero lag sits at index half_span.
    lags = np.concatenate((correlation[-half_span:], correlation[:half_span + 1]))
    peak_lag = np.argmax(np.abs(lags)) - half_span
    return peak_lag / float(interp * fs)

def estimate_frame_doa_gcc(multi_frame, mic_positions, fs):
    """Estimate one direction-of-arrival vector for a multichannel frame.

    Time differences of arrival (TDOAs) between microphone 0 and microphones
    1..3 are measured with ``gcc_phat`` and converted to path-length
    differences using a fixed speed of sound of 343 (presumably m/s, so
    ``mic_positions`` would be in metres -- TODO confirm). The direction is the
    least-squares solution of ``(p_j - p_i) . d = c * tau_ij``.

    Args:
        multi_frame: Array indexed as ``[sample, channel]`` with at least four
            channels.
        mic_positions: Array indexed as ``[mic, coordinate]`` with at least
            four rows.
        fs: Sampling rate passed through to ``gcc_phat``.

    Returns:
        The least-squares direction scaled to unit length, or an all-zero
        vector when the solution has zero length (e.g. every TDOA is 0), so
        callers never receive NaNs from a division by zero.

    Note:
        If the microphone baselines do not span all coordinate axes, the
        system is rank deficient and ``lstsq`` returns the minimum-norm
        solution, i.e. only the component of the direction that lies in the
        span of the baselines.
    """
    pairs = [(0, 1), (0, 2), (0, 3)]  # All pairs w.r.t. mic0
    speed_of_sound = 343
    mic_positions = np.asarray(mic_positions, dtype=float)

    tdoas = np.array(
        [gcc_phat(multi_frame[:, i], multi_frame[:, j], fs) for i, j in pairs]
    )
    baselines = np.vstack([mic_positions[j] - mic_positions[i] for i, j in pairs])
    path_diffs = tdoas * speed_of_sound

    # Solve for direction (simple least squares)
    doa_est, *_ = np.linalg.lstsq(baselines, path_diffs, rcond=None)

    norm = np.linalg.norm(doa_est)
    if norm == 0:
        # No usable delay information; same convention as the bin-wise estimator.
        return np.zeros_like(doa_est)
    return doa_est / norm

def suppress_nonmatching_frames(audio_4ch, fs, doa_true, threshold_deg=20, frame_size=512,
                                mic_positions=None, attenuation=0.1):
    """Attenuate frames whose estimated DoA does not match the target direction.

    The signal is cut into consecutive, non-overlapping frames of
    ``frame_size`` samples. For each frame a direction is estimated with
    ``estimate_frame_doa_gcc`` and compared with ``doa_true``. Frames within
    ``threshold_deg`` of the target pass at full level; all other frames are
    scaled by ``attenuation``. Only channel 0 is written to the output.

    Args:
        audio_4ch: Array indexed as ``[sample, channel]`` with four channels.
        fs: Sampling rate passed to the DoA estimator.
        doa_true: Target direction vector (any non-zero length; it is
            normalised here).
        threshold_deg: Maximum angular deviation, in degrees, for a frame to
            count as matching.
        frame_size: Frame length in samples.
        mic_positions: Optional ``[mic, coordinate]`` array. Defaults to the
            built-in four-microphone geometry.
        attenuation: Gain applied to non-matching frames.

    Returns:
        1-D array of ``audio_4ch.shape[0]`` samples. Samples after the last
        complete frame stay zero.
    """
    if mic_positions is None:
        # NOTE(review): suppress_nonmatching_bins lists mics 2 and 3 with the
        # opposite x sign -- confirm which geometry matches the hardware.
        mic_positions = np.array([
            [-0.05, 0.0, 0.0],
            [ 0.05, 0.0, 0.0],
            [-0.08, 0.045, 0.04],
            [ 0.08, 0.045, 0.04],
        ])
    else:
        mic_positions = np.asarray(mic_positions, dtype=float)

    num_samples = audio_4ch.shape[0]
    num_frames = num_samples // frame_size
    output = np.zeros(num_samples)
    doa_true = np.asarray(doa_true, dtype=float)
    doa_true = doa_true / np.linalg.norm(doa_true)

    for i in range(num_frames):
        start = i * frame_size
        end = start + frame_size
        frame = audio_4ch[start:end, :]

        doa_est = estimate_frame_doa_gcc(frame, mic_positions, fs)
        # Clip guards arccos against round-off pushing the cosine outside [-1, 1].
        cos_angle = np.clip(np.dot(doa_est, doa_true), -1, 1)
        angle_diff = np.degrees(np.arccos(cos_angle))

        # Keep frames that point at the target; quiet everything else.
        gain = 1.0 if angle_diff < threshold_deg else attenuation
        output[start:end] = gain * frame[:, 0]  # Use one mic's channel

        if i % 10 == 0:
            print(f"Processed frame {i+1}/{num_frames} — angle diff = {angle_diff:.2f}°")

    return output

In [None]:
# Same as above, but estimates a DoA per frequency bin and gates each frame on the bin whose DoA is closest to the target direction

def estimate_binwise_doa_gcc(multi_frame, mic_positions, fs):
    """Estimate a direction-of-arrival vector for every frequency bin of a frame.

    The frame is Hann-windowed and transformed with a full (two-sided) FFT
    along the sample axis. For each bin and each microphone pair ``(0, j)``
    the phase of the cross-spectrum ``X_0 * conj(X_j)`` is converted to a
    delay ``tau = -phase / omega`` using the signed bin frequency. The minus
    sign makes ``tau`` the delay of mic 0 relative to mic ``j``, the same
    convention as a time-domain cross-correlation peak. Delays are turned
    into path differences with a speed of sound of 343 and solved by least
    squares against the microphone baselines.

    Args:
        multi_frame: Array indexed as ``[sample, channel]`` with at least four
            channels.
        mic_positions: Array indexed as ``[mic, coordinate]`` with at least
            four rows (presumably metres, given the 343 constant -- TODO
            confirm).
        fs: Sampling rate used to convert bin indices to frequencies.

    Returns:
        Array of shape ``(frame_len, n_coordinates)``; row ``k`` is the unit
        direction for FFT bin ``k``. Rows with no delay information (the DC
        bin, or bins whose solution has zero length) are all zeros.

    Note:
        Phases are only known modulo 2*pi, so bins whose true inter-mic phase
        difference exceeds pi give wrapped (aliased) delays. For real-valued
        input the negative-frequency rows duplicate the positive ones.
    """
    pairs = [(0, 1), (0, 2), (0, 3)]
    speed_of_sound = 343
    first = [i for i, _ in pairs]
    second = [j for _, j in pairs]

    multi_frame = np.asarray(multi_frame)
    mic_positions = np.asarray(mic_positions, dtype=float)
    frame_len = multi_frame.shape[0]

    windowed = multi_frame * hann(frame_len)[:, np.newaxis]
    spectrum = np.fft.fft(windowed, axis=0)  # shape: (n_bins, n_mics)
    # Signed angular frequency per bin, so upper-half bins use negative frequencies.
    omega = 2 * np.pi * np.fft.fftfreq(frame_len, d=1.0 / fs)

    cross = spectrum[:, first] * np.conj(spectrum[:, second])  # (n_bins, n_pairs)

    # DC has no phase slope and therefore no delay information: leave it at 0.
    usable = omega != 0
    tdoas = np.zeros(cross.shape)
    tdoas[usable] = -np.angle(cross[usable]) / omega[usable][:, np.newaxis]

    baselines = mic_positions[second] - mic_positions[first]  # (n_pairs, n_coords)
    path_diffs = tdoas * speed_of_sound

    # One least-squares solve with a right-hand-side column per bin.
    solution, *_ = np.linalg.lstsq(baselines, path_diffs.T, rcond=None)
    directions = solution.T  # (n_bins, n_coords)

    norms = np.linalg.norm(directions, axis=1, keepdims=True)
    return np.divide(directions, norms, out=np.zeros_like(directions), where=norms > 0)

def suppress_nonmatching_bins(audio_4ch, fs, doa_true, threshold_deg=20, frame_size=2048,
                              mic_positions=None):
    """Gate overlapping frames on their best-matching per-bin DoA estimate.

    Frames of ``frame_size`` samples are taken with 50 % overlap. For each
    frame ``estimate_binwise_doa_gcc`` returns one direction per frequency
    bin; the bin whose direction is closest to ``doa_true`` decides the whole
    frame. If that bin is within ``threshold_deg`` of the target, the
    channel-averaged frame is Hann-windowed and overlap-added to the output;
    otherwise the frame contributes nothing.

    Args:
        audio_4ch: Array indexed as ``[sample, channel]`` with four channels.
        fs: Sampling rate passed to the DoA estimator.
        doa_true: Target direction vector (normalised here).
        threshold_deg: Maximum angular deviation, in degrees, of the best bin.
        frame_size: Frame length in samples; the hop is ``frame_size // 2``.
        mic_positions: Optional ``[mic, coordinate]`` array. Defaults to the
            built-in four-microphone geometry.

    Returns:
        1-D array of ``audio_4ch.shape[0]`` samples.
    """
    if mic_positions is None:
        # NOTE(review): suppress_nonmatching_frames lists mics 2 and 3 with the
        # opposite x sign -- confirm which geometry matches the hardware.
        mic_positions = np.array([
            [-0.05, 0.0, 0.0],
            [ 0.05, 0.0, 0.0],
            [0.08, 0.045, 0.04],
            [-0.08, 0.045, 0.04],
        ])
    else:
        mic_positions = np.asarray(mic_positions, dtype=float)

    num_samples = audio_4ch.shape[0]
    hop_size = frame_size // 2

    doa_true = np.asarray(doa_true, dtype=float)
    doa_true = doa_true / np.linalg.norm(doa_true)
    output = np.zeros(num_samples)
    window = hann(frame_size)

    # "+ 1" so a frame that ends exactly on the last sample is still processed.
    for start in range(0, num_samples - frame_size + 1, hop_size):
        stop = start + frame_size
        frame = audio_4ch[start:stop, :]

        doa_binwise = estimate_binwise_doa_gcc(frame, mic_positions, fs)
        # Clip guards arccos against round-off pushing the cosine outside [-1, 1].
        angle_diffs = np.degrees(np.arccos(np.clip(doa_binwise @ doa_true, -1, 1)))
        best_bin = np.argmin(angle_diffs)

        if angle_diffs[best_bin] < threshold_deg:
            output[start:stop] += window * np.mean(frame, axis=1)

        print(f"Processed frame starting at sample {start} — closest bin angle diff = {angle_diffs[best_bin]:.2f}°")

    return output
