In [None]:
### functions to keep
def split_inner_product_mx(a_re, a_im, b_re, b_im):
    ''' calculates the element-wise inner-product integrand conj(a) * b (the
    summand of <a|b>), taking separately the real and imaginary parts of the
    complex arrays a = a_re + i a_im and b = b_re + i b_im

    the first vector is the conjugated one, matching inner_product_mx

    returns (c_re, c_im), the real and imaginary parts of conj(a) * b
    '''
    # conj(a) * b = (a_re - i a_im) * (b_re + i b_im)
    #             = (a_re b_re + a_im b_im) + i (a_re b_im - a_im b_re)
    # (the plain product a * b would have the opposite sign on both cross terms)
    c_re = a_re * b_re + a_im * b_im
    c_im = a_re * b_im - a_im * b_re
    return c_re, c_im


def split_noise_weighted_inner_product_mx(a_re, a_im, b_re, b_im, PSD, duration):
    ''' noise-weighted inner product <a|b>, with the complex arrays a and b
    given as separate real and imaginary parts
    (a = a_re + i a_im, b = b_re + i b_im)

    follows Bilby directly:
        integrand = np.conj(aa) * bb / power_spectral_density
        return 4 / duration * np.sum(integrand)

    PSD: the integrand is divided by it element-wise
    duration: the summed integrand is scaled by 4 / duration

    returns (result_re, result_im), the real and imaginary parts of the result
    '''
    # conj(a) * b in split form. Computed inline because inner_product_mx takes
    # two complex arrays, not four real ones; calling it with the split parts
    # raised a TypeError
    numerator_re = a_re * b_re + a_im * b_im
    numerator_im = a_re * b_im - a_im * b_re
    integrand_re = numerator_re / PSD
    integrand_im = numerator_im / PSD
    # vectorised reduction instead of a Python-level loop; these are
    # real-valued arrays, so the complex64 summation caveat of complex_sum_mx
    # does not apply
    result_re = 4 / duration * integrand_re.sum()
    result_im = 4 / duration * integrand_im.sum()
    return result_re, result_im


def make_dtype_complex64_mx(array):
    ''' converts an mx array of arbitrary dtype to complex64 by calling
    mx.array.astype on it
    '''
    target_dtype = mx.complex64
    return mx.array.astype(array, dtype=target_dtype)


def inner_product_mx(a, b):
    ''' element-wise integrand of the inner product <a|b>: the first vector
    is complex-conjugated and multiplied by the second (no summation here)
    '''
    a_conj = mx.conjugate(a)
    return a_conj * b


def complex_sum_mx(array):
    ''' sums every element of an mx array by hand, since array.sum() doesn't
    play nicely with complex64

    arrays of any shape are accepted; all elements are summed
    returns an mx array; for an empty input this is mx.array(0)
    '''
    # flatten so that indices 0 .. size-1 address individual elements for
    # inputs of any rank (indexing the unflattened array picks along the
    # first axis only, which does not match range(array.size) beyond 1-D)
    flat = mx.reshape(array, (array.size,))
    # start from mx.array(0); adding the elements promotes the accumulator to
    # the element dtype (e.g. complex64)
    total = mx.array(0)
    for idx in range(flat.size):
        total += flat[idx]
    # NOTE(review): this adds one element per step; a vectorised reduction
    # would be faster if the installed MLX supports complex sums — confirm
    return total


def noise_weighted_inner_product_mx(a, b, power_spectral_density, duration):
    ''' noise-weighted inner product <a|b>, taken from Bilby directly:
        integrand = np.conj(aa) * bb / power_spectral_density
        return 4 / duration * np.sum(integrand)
    as in the PyCBC convention, the first vector passed in is the conjugated one

    the summation is done with complex_sum_mx
    '''
    weighted = mx.conjugate(a) * b / power_spectral_density
    prefactor = 4 / duration
    return prefactor * complex_sum_mx(weighted)


def real_mx(array):
    ''' real part of a complex64 array, obtained by casting its dtype
    complex64 -> float32
    '''
    as_float32 = mx.array.astype(array, dtype=mx.float32)
    return as_float32


def imag_mx(array):
    ''' imaginary part of a complex64 array, returned as float32 via real_mx

    with array = x + iy: (x - array) * 1j = (-iy) * i = y, so the real part
    of that product is the imaginary part of the input
    '''
    real_part = real_mx(array)
    rotated = (real_part - array) * 1j
    return real_mx(rotated)


def decompose_complex_array_mx(array):
    ''' splits a complex64 array into its real and imaginary parts, each as a
    float32 array; returns the pair (real, imag)
    '''
    real_part = real_mx(array)
    imag_part = imag_mx(array)
    return real_part, imag_part


def rescale_PSD_mx(power_spectral_density, dynamic_range_factor):
    ''' multiplies the PSD by dynamic_range_factor to rescale its dynamic range

    the PSD is converted to a numpy array first, because MLX does not play
    nicely with the large exponents involved; the result is a numpy array
    '''
    psd_np = np.array(power_spectral_density)
    rescaled = psd_np * dynamic_range_factor
    return rescaled
