In [None]:
### functions to keep
def split_inner_product_mx(a_re, a_im, b_re, b_im):
    ''' calculates the inner product <a|b> = conj(a) * b of two complex vectors
    given as separate real and imaginary parts

    the first vector is the conjugated one, as in inner_product_mx:
    (a_re - i a_im) * (b_re + i b_im)
        = (a_re * b_re + a_im * b_im) + i (a_re * b_im - a_im * b_re)

    a_re, a_im: real and imaginary parts of the conjugated vector a
    b_re, b_im: real and imaginary parts of the vector b
    returns the real and imaginary parts of the elementwise product as a tuple
    '''
    # the signs come from conjugating a, i.e. flipping the sign of a_im
    c_re = a_re * b_re + a_im * b_im
    c_im = a_re * b_im - a_im * b_re
    return c_re, c_im

def split_noise_weighted_inner_product_mx(a_re, a_im, b_re, b_im, PSD, duration):
    ''' noise-weighted inner product of two complex vectors given as separate
    real and imaginary parts

    from Bilby directly:
    integrand = np.conj(aa) * bb / power_spectral_density
    return 4 / duration * np.sum(integrand)
    this also follows the PyCBC convention, the conjugated vector is given first

    a_re, a_im: real and imaginary parts of the conjugated vector a
    b_re, b_im: real and imaginary parts of the vector b
    PSD: power spectral density, divides the integrand elementwise
    duration: divides the prefactor 4
    returns the real and imaginary parts of the result as a tuple
    '''
    # conj(a) * b expanded on the split parts; inner_product_mx cannot be used
    # here because it takes two complex arrays, not four real ones
    numerator_re = a_re * b_re + a_im * b_im
    numerator_im = a_re * b_im - a_im * b_re
    integrand_re = numerator_re / PSD
    integrand_im = numerator_im / PSD
    result_re = 4 / duration * sum(integrand_re)
    result_im = 4 / duration * sum(integrand_im)
    return result_re, result_im

def real_mx_DTYPE_HACK(array):
    ''' this is a silly way to extract the real part of a complex number
    we calculate the real part as (z + conj(z)) / 2, transform it into a string,
    read the number that the printed value starts with, and transform that
    into a float
    dtype: complex64 -> float32

    the array is indexed with array[0] ... array[array.size - 1], so the input
    is treated as one-dimensional
    any array whose dtype does not print as 'mlx.core.complex64' is returned
    unchanged
    raises ValueError when no number can be read from the printed value
    '''
    if str(array.dtype) != 'mlx.core.complex64':
        return array

    import re

    # a signed decimal with optional exponent (or inf / nan) at the start of
    # the text; matching the whole literal keeps values such as 1e-05 intact,
    # which splitting on '+' and '-' would cut in half
    leading_number = re.compile(
        r'[+-]?(?:inf|nan|(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?)'
    )

    container = []
    for iii in range(array.size):
        z = array[iii]
        real = (z + mx.conjugate(z))/2
        # everything after the first '(' of the printed scalar begins with
        # the real part
        printed = str(real).split('(')[1]
        found = leading_number.match(printed)
        if found is None:
            raise ValueError(f'could not read a real part from {printed!r}')
        container.append(float(found.group(0)))
    # built once after the loop, so an empty input yields an empty array
    return mx.array(container)

def imag_mx_DTYPE_HACK(array):
    ''' this is a silly way to extract the imaginary part of a complex number
    in the same way as real_mx_DTYPE_HACK
    we calculate (z - conj(z)) / 2j, which carries the imaginary part of z as
    its leading (real) component, transform it into a string, read the number
    that the printed value starts with, and transform that into a float
    dtype: complex64 -> float32

    the array is indexed with array[0] ... array[array.size - 1], so the input
    is treated as one-dimensional
    any array whose dtype does not print as 'mlx.core.complex64' is returned
    unchanged
    raises ValueError when no number can be read from the printed value
    '''
    if str(array.dtype) != 'mlx.core.complex64':
        return array

    import re

    # a signed decimal with optional exponent (or inf / nan) at the start of
    # the text; matching the whole literal keeps values such as 1e-05 intact,
    # which splitting on '+' and '-' would cut in half
    leading_number = re.compile(
        r'[+-]?(?:inf|nan|(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?)'
    )

    container = []
    for iii in range(array.size):
        z = array[iii]
        imag = (z - mx.conjugate(z))/2j
        # everything after the first '(' of the printed scalar begins with
        # the component we want
        printed = str(imag).split('(')[1]
        found = leading_number.match(printed)
        if found is None:
            raise ValueError(f'could not read an imaginary part from {printed!r}')
        container.append(float(found.group(0)))
    # built once after the loop, so an empty input yields an empty array
    return mx.array(container)

def make_dtype_complex64_mx(array):
    ''' casts an array to complex64 through mx.array.astype
    '''
    target_dtype = mx.complex64
    return mx.array.astype(array, dtype=target_dtype)

def inner_product_mx(a, b):
    ''' product conj(a) * b, the conjugated vector is given first
    '''
    a_conjugated = mx.conjugate(a)
    return a_conjugated * b

def complex_sum_mx(array):
    ''' adds up the entries array[0] ... array[array.size - 1] one at a time,
    starting from mx.array(0)
    done by hand since array.sum() doesn't play nicely with complex64
    '''
    running_total = mx.array(0)
    for position in range(array.size):
        entry = array[position]
        running_total += entry
    return running_total

def noise_weighted_inner_product_mx(a, b, PSD, duration):
    ''' noise-weighted inner product, from Bilby directly:
    integrand = np.conj(aa) * bb / power_spectral_density
    return 4 / duration * np.sum(integrand)
    this also follows the PyCBC convention, the conjugated vector is given first

    the sum over the integrand is done with complex_sum_mx
    '''
    weighted = mx.conjugate(a) * b / PSD
    prefactor = 4 / duration
    return prefactor * complex_sum_mx(weighted)
    
def real_mx(array):
    ''' keeps the real part of a complex64 array by casting it to float32
    dtype: complex64 -> float32
    '''
    target_dtype = mx.float32
    return mx.array.astype(array, dtype=target_dtype)

def imag_mx(array):
    ''' keeps the imaginary part of a complex64 array using only real_mx
    dtype: complex64 -> float32

    NOTE(review): the real parts are subtracted from each other, so a
    non-finite real part presumably turns into NaN here - confirm if needed
    '''
    real_part = real_mx(array)
    # Re(z) - z = -i Im(z), and multiplying by i moves Im(z) into the real slot
    rotated = (real_part - array) * 1j
    return real_mx(rotated)

def decompose_complex_array_mx(array):
    ''' splits a complex64 array into its real and imaginary parts,
    returned as a (real, imaginary) tuple of float32 arrays
    '''
    real_part = real_mx(array)
    imag_part = imag_mx(array)
    return real_part, imag_part