In [1]:
import numpy as np
import scipy as sp
from numpy import ndarray, nan

import ctypes

import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import LogNorm

from IPython.display import display, Audio

from functools import cache

In [2]:
%load_ext autoreload
%autoreload 2
%load_ext cython

In [None]:
@cache
def libDUET() -> ctypes.CDLL:
    lib = ctypes.cdll.LoadLibrary('libDUET.dylib')
    lib.duet_init.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_float, ctypes.c_float, ctypes.c_float]
    lib.duet_init.restype = ctypes.c_int
    lib.duet_deinit.restype = None
    lib.duet_reset.restype = None
    lib.duet_get_n_channels.restype = ctypes.c_int
    lib.duet_get_n_freq.restype = ctypes.c_int
    lib.duet_get_n_time.restype = ctypes.c_int
    lib.duet_get_audio_frame_size.restype = ctypes.c_int
    lib.duet_process_audio_frame.argtypes = [ctypes.POINTER(ctypes.c_int16), ctypes.POINTER(ctypes.c_int)]
    lib.duet_process_audio_frame.restype = ctypes.POINTER(ctypes.c_float * 2)
    return lib

def duet_init(
    window_size: int = 256, # size of the STFT window, must be a power of 2 between 8 and 8192 (inclusive)
    n_samples: int = 1536,   # number of audio samples to use for identification, must be a multiple of window_size/2
    p: float = 1.0,         # symmetric attenuation estimator value weight
    q: float = 0.0,         # delay estimator value weight
    point_threshold: float = 0.5 # point threshold for filtering
) -> None:
    """
    Initializes the Duet audio processing library.
    This function must be called before using any other functions in the library.
    It initializes the FFT library, precomputes numerous values, and sets up the
    necessary parameters for the DUET algorithm along with allocating memory for
    various buffers.

    Parameters
    ----------
    window_size : int
        Size of the STFT window, must be a power of 2 between 8 and 8192 (inclusive)
    n_samples : int
        Number of audio samples to use for identification, must be a multiple of window_size/2
    p : float
        Symmetric attenuation estimator value weight
    q : float
        Delay estimator value weight
    point_threshold : float
        Point threshold for filtering
    """
    retval = libDUET().duet_init(window_size, n_samples, p, q, point_threshold)
    if retval != 0:
        raise RuntimeError(f"duet_init failed with error code {retval}")

def duet_deinit() -> None:
    """
    Deinitialize the DUET audio processing library. This frees any resources
    that were allocated by duet_init() and resets the state of the library.

    This must be called to change the DUET parameters.
    """
    libDUET().duet_deinit()

def duet_reset() -> None:
    """
    Reset the DUET state for completely new, unrelated, audio. Does not
    deallocate or deinitialize any buffers or precomputed values.

    Typically this is called when switching to a new audio source but while
    keeping the same DUET parameters.
    """
    libDUET().duet_reset()

def duet_get_n_channels() -> int:
    """Returns the number of audio channels. Hardcoded to 2 for stereo audio."""
    return libDUET().duet_get_n_channels()

def duet_get_n_freq() -> int:
    """Returns the number of frequency bins in the STFT. Equal to `window_size / 2`."""
    return libDUET().duet_get_n_freq()

def duet_get_n_time() -> int:
    """Returns the number of active audio frames used for identification. Equal to `n_samples / (window_size / 2) + 1`."""
    return libDUET().duet_get_n_time()

def duet_get_audio_frame_size() -> int:
    """Returns the audio frame size. Equal to `window_size / 2 * decimation` (where `decimation` is the downsampling factor, e.g. 3)."""
    return libDUET().duet_get_audio_frame_size()

def duet_process_audio_frame(frame: np.ndarray) -> np.ndarray:
    """
    Add the new audio frame to the existing audio buffer and process it with
    DUET. The new audio frame is interleaved channel data with audio frame size
    samples for each channel (see `duet_get_audio_frame_size()`). Overall size
    must be exactly `n_channels * audio_frame_size` samples.

    This returns the demixed sources of shape (n_sources, N_FREQ, N_TIME) i.e.
    for each source there is a mono-spectrogram across all active audio frames.

    Due to incremental processing, it is recommended to call this N_TIME times
    before attempting to interpret the results.

    Parameters
    ----------
    frame : np.ndarray
        A 1D numpy array of int16 audio samples. The length of the array must be
        equal to the audio frame size times the number of channels.

    Returns
    -------
    demixed : np.ndarray
        A 3D numpy array of shape (n_sources, N_FREQ, N_TIME) containing the
        demixed audio sources as complex-valued spectrograms.
    """
    # validate input
    if not isinstance(frame, np.ndarray):
        raise ValueError("Input frame must be a numpy array")
    if frame.dtype != np.int16:
        raise ValueError("Input frame must be of dtype int16")
    if frame.ndim != 1:
        raise ValueError("Input frame must be a 1D numpy array")
    if frame.size != duet_get_audio_frame_size() * duet_get_n_channels():
        raise ValueError(f"Input frame must have exactly {duet_get_audio_frame_size() * duet_get_n_channels()} samples")
    if not frame.flags['C_CONTIGUOUS']:
        print("Warning: Input frame is not C contiguous, making a copy")
        frame = np.ascontiguousarray(frame)

    # call C function
    n_sources_ct = ctypes.c_int()
    print(n_sources_ct)
    result_ctypes = libDUET().duet_process_audio_frame(
        frame.ctypes.data_as(ctypes.POINTER(ctypes.c_int16)),
        ctypes.pointer(n_sources_ct)
    )
    print(n_sources_ct.value, result_ctypes)
    if not result_ctypes:
        raise RuntimeError("duet_process_audio_frame failed")

    # convert to numpy array
    n_sources = n_sources_ct.value
    n_freq = duet_get_n_freq()
    n_time = duet_get_n_time()
    if n_sources <= 0:
        return np.empty((0, n_freq, n_time), dtype=np.complex64)

    # TODO: convert to complex64? this line may need work
    return np.ctypeslib.as_array(result_ctypes, (n_sources, n_freq, n_time))


In [4]:
duet_init(window_size=256, n_samples=1536, point_threshold=0.5)
duet_get_n_channels(), duet_get_n_freq(), duet_get_n_time(), duet_get_audio_frame_size()

(2, 128, 13, 384)

In [None]:
# This should not change anything until deinit is called
#duet_deinit()
duet_init(window_size=512, n_samples=1536, point_threshold=0.5)
duet_get_n_channels(), duet_get_n_freq(), duet_get_n_time(), duet_get_audio_frame_size()

(2, 128, 13, 384)

In [9]:
duet_reset()

In [8]:
duet_process_audio_frame(np.zeros(duet_get_audio_frame_size() * duet_get_n_channels(), dtype=np.int16))

c_int(0)
0 <__main__.LP_c_float_Array_2 object at 0x3020788c0>


array([], shape=(0, 128, 13), dtype=complex64)